import torch
import torch.nn.functional as F
from tqdm import trange
import time
from tokenxxx import *
from main import *
from duckduckgo_search import DDGS

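# The names referenced below (END_OF_TEXT_TOKEN, SYSTEM_PROMPT, the models, tokenizers and device)
# are expected to come from the star imports above; each try/except supplies a conservative
# fallback so the module still imports when one of them is missing.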
try: END_OF_TEXT_TOKEN
except NameError: END_OF_TEXT_TOKEN = ""
try: SYSTEM_PROMPT
except NameError: SYSTEM_PROMPT = "System: Provide ultra-fast, coherent, consistent, precise and meaningful answers, with logical and deep reasoning."
try: MAX_XDD
except NameError: MAX_XDD = 5
try: codegen_model
except NameError: codegen_model = None
try: codegen_tokenizer
except NameError: codegen_tokenizer = None
try: summarization_model
except NameError: summarization_model = None
try: summarization_tokenizer
except NameError: summarization_tokenizer = None
try: model_gpt2
except NameError: model_gpt2 = None
try: enc
except NameError: enc = None
try: device
except NameError: device = "cpu"

if torch.device(device).type == "cuda": torch.backends.cudnn.benchmark = True
MAX_GENERATION_LENGTH = 512

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    top_k = min(top_k, logits.size(-1))
    if top_k > 0: logits[logits < torch.topk(logits, top_k)[0][..., -1, None]] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1); cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Shift the mask right so the first token that crosses the top_p threshold is kept.
        sorted_indices_to_remove = cumulative_probs > top_p; sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone(); sorted_indices_to_remove[..., 0] = 0
        # Map the mask back to vocabulary order instead of indexing rows of the batched logits.
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove); logits[indices_to_remove] = filter_value
    return logits

def _generate_sequence(model_call, context_tensor, generated, decode_fn, end_token_condition, temperature, top_k, top_p, repetition_penalty, max_length):
    past_key_values = None; last_token = None; repetition_count = 0
    input_tensor = context_tensor
    for _ in range(max_length):
        try: outputs = model_call(input_tensor, past_key_values)
        except Exception as e: yield "<ERROR:" + str(e) + ">"; yield "<END_STREAM>"; return
        next_token_logits = outputs[0][:, -1, :]; past_key_values = outputs[1]
        if temperature > 0:
            # Only scale when sampling; temperature == 0 means greedy decoding and would divide by zero.
            next_token_logits = next_token_logits / temperature
        for token_index in set(generated): next_token_logits[0, token_index] /= repetition_penalty
        filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
        if temperature == 0: next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(0)
        else: next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
        token_id = next_token.tolist()[0][0]
        if token_id == last_token: repetition_count += 1
        else: repetition_count = 0; last_token = token_id
        if repetition_count >= 10: yield "<END_STREAM>"; return
        generated.append(token_id); yield decode_fn(token_id)
        if end_token_condition(token_id): yield "<END_STREAM>"; return
        # With the KV cache held in past_key_values, only the newly sampled token is fed on the next step.
        input_tensor = next_token

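# Streams text from the GPT-2-style model: encodes the prompt with `enc` and stops when the
# END_OF_TEXT_TOKEN entry of its encoder is produced.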
def sample_sequence(prompt, model, enc, max_length=MAX_GENERATION_LENGTH, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
    context_tokens = enc.encode(prompt)
    context_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
    return _generate_sequence(lambda ct, past: model(ct, past_key_values=past), context_tensor, list(context_tokens), lambda token: enc.decode([token]), lambda token: token == enc.encoder[END_OF_TEXT_TOKEN], temperature, top_k, top_p, repetition_penalty, max_length)

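# Same streaming loop for the CodeGen model, using the Hugging Face call signature and the
# GPT-2/CodeGen end-of-text id 50256 as the stop condition.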
def sample_sequence_codegen(prompt, model, tokenizer, max_length=MAX_GENERATION_LENGTH, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
    context_tokens = tokenizer.encode(prompt)
    context_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
    return _generate_sequence(lambda ct, past: model(input_ids=ct, past_key_values=past, labels=None), context_tensor, list(context_tokens), lambda token: tokenizer.decode([token]), lambda token: token == 50256, temperature, top_k, top_p, repetition_penalty, max_length)

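# Summarizes with the loaded seq2seq model when available; otherwise falls back to simple
# truncation of the input text.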
def summarize_text(text):
    if summarization_model and summarization_tokenizer:
        input_ids = summarization_tokenizer.encode(text, return_tensors="pt", truncation=True, max_length=1024).to(device)
        summary_ids = summarization_model.generate(input_ids, max_length=150, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
        return summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return text[:300] + "..." if len(text) > 300 else text

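# Top-level streaming entry point: augments the prompt with DuckDuckGo search snippets, routes the
# request to the code, summarization or GPT-2 path, and keeps generating in overlapping chunks
# until a model signals <END_STREAM>.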
def perform_reasoning_stream(text_input, temperature, top_k, top_p, repetition_penalty, prev_context=""):
    initial_prompt = SYSTEM_PROMPT + "\n\nUser: " + text_input + "\nAssistant:"
    reasoning_prompt = prev_context if prev_context else initial_prompt
    ddgs = DDGS()
    search_results = [r for r in ddgs.text(text_input, max_results=MAX_XDD)]
    if search_results:
        reasoning_prompt += "\nWeb Search Results:\n"
        for result in search_results: reasoning_prompt += "- " + result['body'] + "\n"
    reasoning_prompt += "\n"
    if "code" in text_input.lower() or "program" in text_input.lower(): model_type = "code"
    elif "summarize" in text_input.lower() or "summary" in text_input.lower(): model_type = "summarize"
    elif model_gpt2 and enc: model_type = "gpt2"
    else: yield "<ERROR: No suitable model found>"; yield "<END_STREAM>"; return
    if model_type == "summarize":
        # summarize_text already falls back to plain truncation when no summarization model is loaded.
        summary = summarize_text(text_input); yield "SUMMARY_TEXT:" + summary; yield "<END_STREAM>"; return
    accumulated_text = ""; current_context = reasoning_prompt; overlap = 256
    while True:
        if model_type == "code": generator = sample_sequence_codegen(current_context, codegen_model, codegen_tokenizer, MAX_GENERATION_LENGTH, temperature, top_k, top_p, repetition_penalty, device)
        elif model_type == "gpt2": generator = sample_sequence(current_context, model_gpt2, enc, MAX_GENERATION_LENGTH, temperature, top_k, top_p, repetition_penalty, device)
        chunk_text = ""; finished = False
        for token in generator:
            if token == "<END_STREAM>": finished = True; break
            chunk_text += token
        if accumulated_text:
            # Strip the re-fed overlap so the streamed output is not duplicated across chunks.
            overlap_text = accumulated_text[-overlap:]
            if chunk_text.startswith(overlap_text): chunk_text = chunk_text[len(overlap_text):]
        accumulated_text += chunk_text; yield chunk_text
        if finished: yield "<END_STREAM>"; break
        current_context = accumulated_text[-overlap:] if len(accumulated_text) > overlap else accumulated_text
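
# Example usage: a minimal sketch assuming model_gpt2 and enc were loaded by main/tokenxxx;
# the prompt and sampling values below are illustrative only.
if __name__ == "__main__":
    for piece in perform_reasoning_stream("Explain how top-p sampling works", temperature=0.8, top_k=40, top_p=0.9, repetition_penalty=1.2):
        if piece == "<END_STREAM>": break
        print(piece, end="", flush=True)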