import torch
import torch.nn.functional as F
from tqdm import trange
import time
from tokenxxx import *
from main import *
from duckduckgo_search import DDGS
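# NOTE: tokenxxx and main are expected to provide the tokenizer (enc), the GPT-2 model
# (model_gpt2), the optional CodeGen and summarization models, END_OF_TEXT_TOKEN,
# SYSTEM_PROMPT and the target device via wildcard import. The try/except blocks below
# only install safe fallbacks when those names are missing, so this module can still be
# imported on its own.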
try:
    END_OF_TEXT_TOKEN
except NameError:
    END_OF_TEXT_TOKEN = ""
try:
    SYSTEM_PROMPT
except NameError:
    SYSTEM_PROMPT = "System: Provide ultra-fast, coherent, consistent, precise and meaningful answers, with logical and deep reasoning."
try:
    MAX_XDD
except NameError:
    MAX_XDD = 5
try:
    codegen_model
except NameError:
    codegen_model = None
try:
    codegen_tokenizer
except NameError:
    codegen_tokenizer = None
try:
    summarization_model
except NameError:
    summarization_model = None
try:
    summarization_tokenizer
except NameError:
    summarization_tokenizer = None
try:
    model_gpt2
except NameError:
    model_gpt2 = None
try:
    enc
except NameError:
    enc = None
try:
    device
except NameError:
    device = "cpu"
if torch.device(device).type == "cuda":
    torch.backends.cudnn.benchmark = True
MAX_GENERATION_LENGTH = 512
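# Standard top-k / nucleus (top-p) filtering: logits outside the top-k set, or outside the
# smallest set whose cumulative probability exceeds top_p, are pushed to -inf so that
# softmax + multinomial sampling can never pick them. Expects logits of shape [batch, vocab].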
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Remove every token whose logit is below the k-th largest logit.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold, always keeping
        # the most likely token.
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Scatter the mask back to the original vocabulary order; plain fancy indexing only
        # works for 1-D logits, but these are [batch, vocab].
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
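# Minimal usage sketch for the filter (illustrative values, not taken from main.py):
#   logits = torch.randn(1, 50257)                      # one step of vocabulary logits
#   filtered = top_k_top_p_filtering(logits.clone(), top_k=40, top_p=0.9)
#   probs = F.softmax(filtered, dim=-1)
#   next_token = torch.multinomial(probs, num_samples=1)

# _generate_sequence is a generator: it yields decoded text fragments one token at a time and
# uses the in-band sentinel strings "<END_STREAM>" and "<ERROR:...>", so callers must check for
# them before concatenating. It also aborts after 10 consecutive identical tokens to avoid loops.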
def _generate_sequence(model_call, context_tensor, generated, decode_fn, end_token_condition, temperature, top_k, top_p, repetition_penalty, max_length):
    past_key_values = None
    last_token = None
    repetition_count = 0
    for _ in range(max_length):
        try:
            outputs = model_call(context_tensor, past_key_values)
        except Exception as e:
            yield "<ERROR:" + str(e) + ">"
            yield "<END_STREAM>"
            return
        next_token_logits = outputs[0][:, -1, :]
        past_key_values = outputs[1]
        # Scale by temperature only when sampling; temperature == 0 means greedy decoding
        # (dividing unconditionally, as before, would divide by zero).
        if temperature > 0:
            next_token_logits = next_token_logits / temperature
        # Apply a simple repetition penalty to every token generated so far.
        for token_index in set(generated):
            next_token_logits[0, token_index] /= repetition_penalty
        filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
        if temperature == 0:
            next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(0)
        else:
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
        token_id = next_token.tolist()[0][0]
        if token_id == last_token:
            repetition_count += 1
        else:
            repetition_count = 0
        last_token = token_id
        if repetition_count >= 10:
            yield "<END_STREAM>"
            return
        generated.append(token_id)
        # With past_key_values cached, only the newly sampled token is fed on the next step.
        context_tensor = next_token
        token_decoded = decode_fn(token_id)
        yield token_decoded
        if end_token_condition(token_id):
            yield "<END_STREAM>"
            return
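# sample_sequence wraps _generate_sequence for the GPT-2-style model: the prompt is encoded with
# enc, fed as a [1, len] tensor, and generation stops when the end-of-text token id is produced.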
def sample_sequence(prompt, model, enc, max_length=MAX_GENERATION_LENGTH, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
    context_tokens = enc.encode(prompt)
    context_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
    # .get(...) avoids a KeyError when END_OF_TEXT_TOKEN is the empty fallback value.
    end_token_id = enc.encoder.get(END_OF_TEXT_TOKEN, -1)
    return _generate_sequence(
        lambda ct, past: model(ct, past_key_values=past),
        context_tensor,
        list(context_tokens),
        lambda token: enc.decode([token]),
        lambda token: token == end_token_id,
        temperature, top_k, top_p, repetition_penalty, max_length
    )
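# The CodeGen variant below differs only in how the model is called (keyword input_ids) and in
# the end-of-text check: 50256 is the eos id of the GPT-2 BPE vocabulary that CodeGen reuses.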
def sample_sequence_codegen(prompt, model, tokenizer, max_length=MAX_GENERATION_LENGTH, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
    context_tokens = tokenizer.encode(prompt)
    context_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
    return _generate_sequence(
        lambda ct, past: model(input_ids=ct, past_key_values=past, labels=None),
        context_tensor,
        list(context_tokens),
        lambda token: tokenizer.decode([token]),
        lambda token: token == 50256,
        temperature, top_k, top_p, repetition_penalty, max_length
    )
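# summarize_text prefers the neural summarizer when it was loaded (beam search with a small
# length penalty) and otherwise falls back to a plain 300-character truncation.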
def summarize_text(text):
    if summarization_model and summarization_tokenizer:
        input_ids = summarization_tokenizer.encode(text, return_tensors="pt", truncation=True, max_length=1024).to(device)
        summary_ids = summarization_model.generate(
            input_ids,
            max_length=150,
            min_length=40,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True
        )
        return summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return text[:300] + "..." if len(text) > 300 else text
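# perform_reasoning_stream ties everything together: it builds a prompt from SYSTEM_PROMPT,
# the user input and optional DuckDuckGo snippets, routes the request to the CodeGen, summarizer
# or GPT-2 backend, and then streams text in chunks, carrying a 256-character overlap between
# passes so long answers can continue past MAX_GENERATION_LENGTH tokens.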
def perform_reasoning_stream(text_input, temperature, top_k, top_p, repetition_penalty, prev_context=""):
    initial_prompt = SYSTEM_PROMPT + "\n\nUser: " + text_input + "\nAssistant:"
    reasoning_prompt = prev_context if prev_context else initial_prompt
    # Enrich the prompt with web search snippets; a failed search should not kill the stream.
    try:
        search_results = list(DDGS().text(text_input, max_results=MAX_XDD))
    except Exception:
        search_results = []
    if search_results:
        reasoning_prompt += "\nWeb Search Results:\n"
        for result in search_results:
            reasoning_prompt += "- " + result['body'] + "\n"
        reasoning_prompt += "\n"
    # Pick a backend: CodeGen for code requests (only if it was actually loaded), the summarizer
    # for summary requests, otherwise the GPT-2 model.
    if ("code" in text_input.lower() or "program" in text_input.lower()) and codegen_model and codegen_tokenizer:
        model_type = "code"
    elif "summarize" in text_input.lower() or "summary" in text_input.lower():
        model_type = "summarize"
    elif model_gpt2 and enc:
        model_type = "gpt2"
    else:
        yield "<ERROR: No suitable model found>"
        yield "<END_STREAM>"
        return
    if model_type == "summarize":
        if summarization_model:
            summary = summarize_text(text_input)
            yield "SUMMARY_TEXT:" + summary
        yield "<END_STREAM>"
        return
    accumulated_text = ""
    current_context = reasoning_prompt
    overlap = 256
    while True:
        if model_type == "code":
            generator = sample_sequence_codegen(current_context, codegen_model, codegen_tokenizer, MAX_GENERATION_LENGTH, temperature, top_k, top_p, repetition_penalty, device)
        elif model_type == "gpt2":
            generator = sample_sequence(current_context, model_gpt2, enc, MAX_GENERATION_LENGTH, temperature, top_k, top_p, repetition_penalty, device)
        chunk_text = ""
        finished = False
        for token in generator:
            if token == "<END_STREAM>":
                finished = True
                break
            chunk_text += token
        # Drop the part of the new chunk that repeats the overlap carried over from the last pass.
        if accumulated_text:
            overlap_text = accumulated_text[-overlap:]
            if chunk_text.startswith(overlap_text):
                chunk_text = chunk_text[len(overlap_text):]
        accumulated_text += chunk_text
        yield chunk_text
        if finished:
            yield "<END_STREAM>"
            break
        # Continue generation from the tail of what has been produced so far.
        current_context = accumulated_text[-overlap:] if len(accumulated_text) > overlap else accumulated_text
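if __name__ == "__main__":
    # Minimal sketch of how the streaming API is consumed; the sampling parameters below are
    # illustrative assumptions, not values taken from main.py.
    for chunk in perform_reasoning_stream(
        "Explain how transformers generate text one token at a time",
        temperature=0.7,
        top_k=40,
        top_p=0.9,
        repetition_penalty=1.1,
    ):
        if chunk == "<END_STREAM>":
            break
        print(chunk, end="", flush=True)
    print()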