import torch
import torch.nn.functional as F
# The wildcard imports are expected to provide the tokenizer/model globals
# used below (END_OF_TEXT_TOKEN, enc, model_gpt2, codegen_model, ...).
from tokenxxx import *
from main import *
from duckduckgo_search import DDGS

# Fallback definitions for globals that `tokenxxx` / `main` may not provide.
try:
    END_OF_TEXT_TOKEN
except NameError:
    END_OF_TEXT_TOKEN = ""
try:
    SYSTEM_PROMPT
except NameError:
    SYSTEM_PROMPT = "Sistema: Proporcione respuestas ultra rápidas, coherentes, similares, precisas y con sentido, con razonamiento lógico y profundo."
try:
    MAX_XDD
except NameError:
    MAX_XDD = 5
try:
    codegen_model
except NameError:
    codegen_model = None
try:
    codegen_tokenizer
except NameError:
    codegen_tokenizer = None
try:
    summarization_model
except NameError:
    summarization_model = None
try:
    summarization_tokenizer
except NameError:
    summarization_tokenizer = None
try:
    model_gpt2
except NameError:
    model_gpt2 = None
try:
    enc
except NameError:
    enc = None
try:
    device
except NameError:
    device = "cpu"

# Enable cuDNN autotuning when running on a GPU.
if torch.device(device).type == "cuda":
    torch.backends.cudnn.benchmark = True

MAX_GENERATION_LENGTH = 512

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    # Filter a [batch, vocab] logits tensor using top-k and/or nucleus (top-p) filtering.
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Remove every token whose logit is below the k-th largest logit.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., [-1]]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold,
        # always keeping at least the most probable token.
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Scatter the mask back from sorted order to vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
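
# Illustrative sketch (not executed at import time); shapes and values are
# placeholders showing how the filter is typically combined with sampling:
#   logits = torch.randn(1, 50257)                       # [batch, vocab]
#   filtered = top_k_top_p_filtering(logits, top_k=40, top_p=0.9)
#   probs = F.softmax(filtered, dim=-1)
#   next_token = torch.multinomial(probs, num_samples=1)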

def _generate_sequence(model_call, context_tensor, generated, decode_fn, end_token_condition, temperature, top_k, top_p, repetition_penalty, max_length):
    past_key_values = None
    last_token = None
    repetition_count = 0
    for _ in range(max_length):
        try:
            with torch.no_grad():
                outputs = model_call(context_tensor, past_key_values)
        except Exception as e:
            yield "<ERROR:" + str(e) + ">"
            yield "<END_STREAM>"
            return
        next_token_logits = outputs[0][:, -1, :]
        past_key_values = outputs[1]
        if temperature > 0:
            next_token_logits = next_token_logits / temperature
        # Repetition penalty: push already-seen tokens toward lower probability
        # regardless of the sign of their logit.
        for token_index in set(generated):
            if next_token_logits[0, token_index] < 0:
                next_token_logits[0, token_index] *= repetition_penalty
            else:
                next_token_logits[0, token_index] /= repetition_penalty
        filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
        if temperature == 0:
            next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(0)
        else:
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
        token_id = next_token.tolist()[0][0]
        if token_id == last_token:
            repetition_count += 1
        else:
            repetition_count = 0
            last_token = token_id
        # Abort the stream if the exact same token repeats too many times in a row.
        if repetition_count >= 10:
            yield "<END_STREAM>"
            return
        generated.append(token_id)
        yield decode_fn(token_id)
        if end_token_condition(token_id):
            yield "<END_STREAM>"
            return
        # With cached past_key_values, only the newly sampled token is fed on the next step.
        context_tensor = next_token
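
# _generate_sequence streams decoded token strings and always closes the stream
# with the "<END_STREAM>" sentinel (end-of-text, error, repetition abort, or
# max_length); the wrappers below bind it to a specific model and tokenizer.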

def sample_sequence(prompt, model, enc, max_length=MAX_GENERATION_LENGTH, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
    context_tokens = enc.encode(prompt)
    context_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
    return _generate_sequence(
        lambda ct, past: model(ct, past_key_values=past),
        context_tensor,
        list(context_tokens),
        lambda token: enc.decode([token]),
        # Stop on the encoder's end-of-text token; .get() guards against the
        # empty END_OF_TEXT_TOKEN fallback defined above.
        lambda token: token == enc.encoder.get(END_OF_TEXT_TOKEN, -1),
        temperature, top_k, top_p, repetition_penalty, max_length
    )
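
# Usage sketch (assumes `model_gpt2` and `enc` were supplied by the wildcard
# imports above):
#   for piece in sample_sequence("Hello", model_gpt2, enc, top_k=40, top_p=0.9, device=device):
#       if piece == "<END_STREAM>":
#           break
#       print(piece, end="")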

def sample_sequence_codegen(prompt, model, tokenizer, max_length=MAX_GENERATION_LENGTH, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
    context_tokens = tokenizer.encode(prompt)
    context_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
    return _generate_sequence(
        lambda ct, past: model(input_ids=ct, past_key_values=past, labels=None),
        context_tensor,
        list(context_tokens),
        lambda token: tokenizer.decode([token]),
        lambda token: token == 50256,  # 50256 is the GPT-2/CodeGen end-of-text token id
        temperature, top_k, top_p, repetition_penalty, max_length
    )

def summarize_text(text):
    if summarization_model and summarization_tokenizer:
        input_ids = summarization_tokenizer.encode(text, return_tensors="pt", truncation=True, max_length=1024).to(device)
        summary_ids = summarization_model.generate(
            input_ids,
            max_length=150,
            min_length=40,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True
        )
        return summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    # Fallback when no summarization model is loaded: naive truncation.
    return (text[:300] + "...") if len(text) > 300 else text
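
# Usage sketch (assumes a seq2seq summarizer, e.g. BART, was loaded into
# `summarization_model` / `summarization_tokenizer` by the wildcard imports):
#   print(summarize_text(long_article_text))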

def perform_reasoning_stream(text_input, temperature, top_k, top_p, repetition_penalty, prev_context=""):
    initial_prompt = SYSTEM_PROMPT + "\n\nUser: " + text_input + "\nAssistant:"
    reasoning_prompt = prev_context if prev_context else initial_prompt
    # Augment the prompt with web search results; search failures are non-fatal.
    try:
        search_results = list(DDGS().text(text_input, max_results=MAX_XDD))
    except Exception:
        search_results = []
    if search_results:
        reasoning_prompt += "\nWeb Search Results:\n"
        for result in search_results:
            reasoning_prompt += "- " + result.get("body", "") + "\n"
        reasoning_prompt += "\n"
    # Route the request to the most suitable model that is actually loaded,
    # falling back to GPT-2 when a specialised model is unavailable.
    lowered = text_input.lower()
    if ("code" in lowered or "program" in lowered) and codegen_model and codegen_tokenizer:
        model_type = "code"
    elif ("summarize" in lowered or "summary" in lowered) and summarization_model and summarization_tokenizer:
        model_type = "summarize"
    elif model_gpt2 and enc:
        model_type = "gpt2"
    else:
        yield "<ERROR: No suitable model found>"
        yield "<END_STREAM>"
        return
    if model_type == "summarize":
        yield "SUMMARY_TEXT:" + summarize_text(text_input)
        yield "<END_STREAM>"
        return
    accumulated_text = ""
    current_context = reasoning_prompt
    overlap = 256
    while True:
        if model_type == "code":
            generator = sample_sequence_codegen(current_context, codegen_model, codegen_tokenizer, MAX_GENERATION_LENGTH, temperature, top_k, top_p, repetition_penalty, device)
        elif model_type == "gpt2":
            generator = sample_sequence(current_context, model_gpt2, enc, MAX_GENERATION_LENGTH, temperature, top_k, top_p, repetition_penalty, device)
        chunk_text = ""
        finished = False
        for token in generator:
            if token == "<END_STREAM>":
                finished = True
                break
            chunk_text += token
        # Drop the part of the new chunk that merely repeats the overlap
        # carried over as context from the previous chunk.
        if accumulated_text:
            overlap_text = accumulated_text[-overlap:]
            if chunk_text.startswith(overlap_text):
                chunk_text = chunk_text[len(overlap_text):]
        accumulated_text += chunk_text
        yield chunk_text
        if finished:
            yield "<END_STREAM>"
            break
        current_context = accumulated_text[-overlap:] if len(accumulated_text) > overlap else accumulated_text
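
# Minimal usage sketch (assumes the wildcard imports above supplied a loaded
# GPT-2 model as `model_gpt2` and its encoder as `enc`): run the module
# directly to stream a single answer to stdout.
if __name__ == "__main__":
    for chunk in perform_reasoning_stream("Explain top-p sampling.", temperature=0.8, top_k=40, top_p=0.9, repetition_penalty=1.2):
        if chunk == "<END_STREAM>":
            break
        print(chunk, end="", flush=True)
    print()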