import torch
import torch.nn.functional as F
from tqdm import trange
import time
from tokenxxx import *
from main import *
from duckduckgo_search import DDGS
# Fallback defaults for globals expected from the star imports above (tokenxxx / main).
try:
    END_OF_TEXT_TOKEN
except NameError:
    END_OF_TEXT_TOKEN = ""
try:
    SYSTEM_PROMPT
except NameError:
    SYSTEM_PROMPT = "Sistema: Proporcione respuestas ultra rápidas, coherentes, similares, precisas y con sentido, con razonamiento lógico y profundo."
try:
    MAX_XDD
except NameError:
    MAX_XDD = 5
try:
    codegen_model
except NameError:
    codegen_model = None
try:
    codegen_tokenizer
except NameError:
    codegen_tokenizer = None
try:
    summarization_model
except NameError:
    summarization_model = None
try:
    summarization_tokenizer
except NameError:
    summarization_tokenizer = None
try:
    model_gpt2
except NameError:
    model_gpt2 = None
try:
    enc
except NameError:
    enc = None
try:
    device
except NameError:
    device = "cpu"
if torch.device(device).type == "cuda":
    torch.backends.cudnn.benchmark = True
MAX_GENERATION_LENGTH = 512
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Mask every logit below the k-th largest value.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., [-1]]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        # Nucleus filtering: drop tokens once cumulative probability exceeds top_p.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask right so the first token above the threshold is kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits
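# Usage sketch for top_k_top_p_filtering (illustrative values, not from the
# original application): keep only the two highest logits, then sample.
#
#   logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
#   filtered = top_k_top_p_filtering(logits.clone(), top_k=2, top_p=0.9)
#   probs = F.softmax(filtered, dim=-1)                # masked entries get ~0 probability
#   next_token = torch.multinomial(probs, num_samples=1)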
def _generate_sequence(model_call, context_tensor, generated, decode_fn, end_token_condition, temperature, top_k, top_p, repetition_penalty, max_length):
    past_key_values = None
    last_token = None
    repetition_count = 0
    for _ in range(max_length):
        try:
            outputs = model_call(context_tensor, past_key_values)
        except Exception as e:
            yield "<ERROR:" + str(e) + ">"
            yield "<END_STREAM>"
            return
        next_token_logits = outputs[0][:, -1, :]
        past_key_values = outputs[1]
        if temperature > 0:
            # Only scale by temperature when it is non-zero; dividing by zero would blow up the logits.
            next_token_logits = next_token_logits / temperature
        # Simple repetition penalty: scale down the logits of every token already generated.
        for token_index in set(generated):
            next_token_logits[0, token_index] /= repetition_penalty
        filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
        if temperature == 0:
            # Greedy decoding when temperature is zero.
            next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(0)
        else:
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
        # With a key/value cache, only the newly sampled token is fed on the next step.
        context_tensor = next_token
        token_id = next_token.tolist()[0][0]
        if token_id == last_token:
            repetition_count += 1
        else:
            repetition_count = 0
        last_token = token_id
        if repetition_count >= 10:
            # Bail out if the model keeps emitting the same token.
            yield "<END_STREAM>"
            return
        generated.append(token_id)
        token_decoded = decode_fn(token_id)
        yield token_decoded
        if end_token_condition(token_id):
            yield "<END_STREAM>"
            return
def sample_sequence(prompt, model, enc, max_length=MAX_GENERATION_LENGTH, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
    context_tokens = enc.encode(prompt)
    context_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
    return _generate_sequence(
        lambda ct, past: model(ct, past_key_values=past),
        context_tensor,
        list(context_tokens),
        lambda token: enc.decode([token]),
        lambda token: token == enc.encoder[END_OF_TEXT_TOKEN],
        temperature, top_k, top_p, repetition_penalty, max_length
    )
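# Usage sketch (assumes model_gpt2 and enc have been loaded by the star imports
# above; the prompt and sampling settings are illustrative):
#
#   for piece in sample_sequence("Hello, world", model_gpt2, enc,
#                                temperature=0.8, top_k=40, top_p=0.9,
#                                repetition_penalty=1.2, device=device):
#       if piece == "<END_STREAM>":
#           break
#       print(piece, end="", flush=True)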
def sample_sequence_codegen(prompt, model, tokenizer, max_length=MAX_GENERATION_LENGTH, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
    context_tokens = tokenizer.encode(prompt)
    context_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
    return _generate_sequence(
        lambda ct, past: model(input_ids=ct, past_key_values=past, labels=None),
        context_tensor,
        list(context_tokens),
        lambda token: tokenizer.decode([token]),
        lambda token: token == 50256,  # GPT-2 end-of-text id, also used by the CodeGen tokenizer
        temperature, top_k, top_p, repetition_penalty, max_length
    )
def summarize_text(text):
    if summarization_model and summarization_tokenizer:
        input_ids = summarization_tokenizer.encode(text, return_tensors="pt", truncation=True, max_length=1024).to(device)
        summary_ids = summarization_model.generate(
            input_ids,
            max_length=150,
            min_length=40,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True
        )
        return summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return text[:300] + "..." if len(text) > 300 else text
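# Usage note (illustrative): summarize_text falls back to plain truncation when
# no summarization model/tokenizer is loaded, so it is always safe to call:
#
#   short = summarize_text(some_long_string)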
def perform_reasoning_stream(text_input, temperature, top_k, top_p, repetition_penalty, prev_context=""):
    initial_prompt = SYSTEM_PROMPT + "\n\nUser: " + text_input + "\nAssistant:"
    reasoning_prompt = prev_context if prev_context else initial_prompt
    ddgs = DDGS()
    search_results = [r for r in ddgs.text(text_input, max_results=MAX_XDD)]
    if search_results:
        reasoning_prompt += "\nWeb Search Results:\n"
        for result in search_results:
            reasoning_prompt += "- " + result['body'] + "\n"
        reasoning_prompt += "\n"
    # Pick a model, checking availability so the generation loop below never runs without one.
    if ("code" in text_input.lower() or "program" in text_input.lower()) and codegen_model and codegen_tokenizer:
        model_type = "code"
    elif "summarize" in text_input.lower() or "summary" in text_input.lower():
        model_type = "summarize"
    elif model_gpt2 and enc:
        model_type = "gpt2"
    else:
        yield "<ERROR: No se encontró un modelo adecuado>"
        yield "<END_STREAM>"
        return
    if model_type == "summarize":
        # summarize_text already falls back to truncation if no summarization model is loaded.
        summary = summarize_text(text_input)
        yield "SUMMARY_TEXT:" + summary
        yield "<END_STREAM>"
        return
    accumulated_text = ""
    current_context = reasoning_prompt
    overlap = 256
    while True:
        if model_type == "code":
            generator = sample_sequence_codegen(current_context, codegen_model, codegen_tokenizer, MAX_GENERATION_LENGTH, temperature, top_k, top_p, repetition_penalty, device)
        else:
            generator = sample_sequence(current_context, model_gpt2, enc, MAX_GENERATION_LENGTH, temperature, top_k, top_p, repetition_penalty, device)
        chunk_text = ""
        finished = False
        for token in generator:
            if token == "<END_STREAM>":
                finished = True
                break
            chunk_text += token
        if accumulated_text:
            # Drop the overlap that was re-fed as context so it is not emitted twice.
            overlap_text = accumulated_text[-overlap:]
            if chunk_text.startswith(overlap_text):
                chunk_text = chunk_text[len(overlap_text):]
        accumulated_text += chunk_text
        yield chunk_text
        if finished:
            yield "<END_STREAM>"
            break
        current_context = accumulated_text[-overlap:] if len(accumulated_text) > overlap else accumulated_text
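# Minimal manual test (assumption: the star imports above have loaded a GPT-2
# model; requires network access for the DuckDuckGo search; the prompt and
# sampling settings are illustrative, not part of the original application):
if __name__ == "__main__":
    for piece in perform_reasoning_stream("Explain what a binary search does",
                                          temperature=0.8, top_k=40, top_p=0.9,
                                          repetition_penalty=1.2):
        if piece == "<END_STREAM>":
            break
        print(piece, end="", flush=True)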