# text_generation.py from the Hugging Face repository "Hhhh".
# Author: Hjgugugjhuhjggg
# Commit 9cd71e4 (verified): "Update text_generation.py" (7.4 kB)
import torch
import torch.nn.functional as F
from tqdm import trange
import time
from tokenxxx import *
from main import *
from duckduckgo_search import DDGS
# Fallback values for settings that the star-imports above (``tokenxxx`` /
# ``main``) may or may not have provided. A name that is already defined in
# this module keeps its imported value; otherwise the default below is used.
END_OF_TEXT_TOKEN = globals().get("END_OF_TEXT_TOKEN", "")
SYSTEM_PROMPT = globals().get(
    "SYSTEM_PROMPT",
    "Sistema: Proporcione respuestas ultra rápidas, coherentes, similares, precisas y con sentido, con razonamiento lógico y profundo.",
)
# Passed as ``max_results`` to the DuckDuckGo text search.
MAX_XDD = globals().get("MAX_XDD", 5)

# Optional models/tokenizers; ``None`` means "not loaded".
codegen_model = globals().get("codegen_model", None)
codegen_tokenizer = globals().get("codegen_tokenizer", None)
summarization_model = globals().get("summarization_model", None)
summarization_tokenizer = globals().get("summarization_tokenizer", None)
model_gpt2 = globals().get("model_gpt2", None)
enc = globals().get("enc", None)
device = globals().get("device", "cpu")

# Let cuDNN autotune convolution algorithms when running on a CUDA device.
if torch.device(device).type == "cuda":
    torch.backends.cudnn.benchmark = True

# Default cap on the number of tokens sampled per generation call.
MAX_GENERATION_LENGTH = 512
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Mask logits outside the top-k set and/or the top-p (nucleus) set.

    Filtering is applied along the last dimension, so both a 1-D ``(vocab,)``
    tensor and a batched ``(batch, vocab)`` tensor are supported; each row is
    filtered independently.

    Args:
        logits: Float tensor of scores. It is modified in place.
        top_k: Keep only the ``top_k`` highest scores per row (0 disables).
            Values larger than the vocabulary size are clamped.
        top_p: Keep the smallest set of highest-scoring entries whose
            cumulative softmax probability exceeds ``top_p`` (0.0 disables).
            The single most likely entry is always kept.
        filter_value: Value written into every masked position.

    Returns:
        The same ``logits`` tensor after masking.
    """
    if top_k > 0:
        # Clamp so torch.topk never asks for more entries than exist.
        top_k = min(top_k, logits.size(-1))
        kth_best = torch.topk(logits, top_k)[0][..., [-1]]
        logits[logits < kth_best] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_to_remove = cumulative_probs > top_p
        # Shift right so the entry that first crosses the threshold is kept,
        # and never remove the most likely entry.
        sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
        sorted_to_remove[..., 0] = False
        # Map the mask from sorted order back to vocabulary order per row.
        # Indexing ``logits[sorted_indices[mask]]`` would instead index the
        # first (batch) dimension of a 2-D tensor.
        to_remove = sorted_to_remove.scatter(-1, sorted_indices, sorted_to_remove)
        logits[to_remove] = filter_value
    return logits
def _apply_repetition_penalty(logits, token_ids, penalty):
    # Penalise already-seen tokens in row 0: shrink positive scores and push
    # negative ones further down, so a penalty > 1 always discourages them.
    if penalty == 1.0 or not token_ids:
        return logits
    index = torch.tensor(sorted(set(token_ids)), dtype=torch.long, device=logits.device)
    scores = logits[0, index]
    logits[0, index] = torch.where(scores < 0, scores * penalty, scores / penalty)
    return logits


def _pick_next_token(logits, temperature, top_k, top_p):
    # Filter the logits and return the chosen token id as a (batch, 1) long tensor.
    filtered = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    if temperature == 0:
        return torch.argmax(filtered, dim=-1, keepdim=True)
    return torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)


def _generate_sequence(model_call, context_tensor, generated, decode_fn, end_token_condition, temperature, top_k, top_p, repetition_penalty, max_length):
    """Stream decoded tokens sampled autoregressively from ``model_call``.

    Args:
        model_call: ``(input_ids, past_key_values) -> outputs``. ``outputs[0]``
            is indexed as ``(batch, seq, vocab)`` logits and ``outputs[1]`` is
            reused as the cache for the next call.
        context_tensor: Prompt token ids. Presumably a ``(1, seq)`` long
            tensor, since only row 0 is penalised and decoded.
        generated: Token ids seen so far. It is appended to in place and drives
            the repetition penalty.
        decode_fn: Maps a token id to its text.
        end_token_condition: Predicate that returns True for a stop token.
        temperature: Softmax temperature. 0 selects greedy argmax decoding.
        top_k, top_p: Forwarded to ``top_k_top_p_filtering``.
        repetition_penalty: Values > 1 discourage tokens already in
            ``generated``.
        max_length: Maximum number of new tokens to sample.

    Yields:
        The decoded text of each sampled token. If the model call raises,
        yields ``"<ERROR:...>"``. Yields ``"<END_STREAM>"`` after a stop
        token, after a model error, or when the same token would be emitted
        for the 11th time in a row. Exhausting ``max_length`` ends the
        generator without ``"<END_STREAM>"``.
    """
    past_key_values = None
    input_ids = context_tensor
    last_token = None
    repetition_count = 0
    for _ in range(max_length):
        try:
            # Inference only: don't build an autograd graph. The context
            # manager must not span a ``yield``, hence the narrow scope.
            with torch.no_grad():
                outputs = model_call(input_ids, past_key_values)
        except Exception as e:
            yield "<ERROR:" + str(e) + ">"
            yield "<END_STREAM>"
            return
        next_token_logits = outputs[0][:, -1, :]
        past_key_values = outputs[1]
        # Temperature 0 means greedy decoding; dividing by it would produce
        # inf/nan logits.
        if temperature != 0:
            next_token_logits = next_token_logits / temperature
        next_token_logits = _apply_repetition_penalty(next_token_logits, generated, repetition_penalty)
        next_token = _pick_next_token(next_token_logits, temperature, top_k, top_p)
        token_id = next_token[0, 0].item()
        if token_id == last_token:
            repetition_count += 1
        else:
            repetition_count = 0
        last_token = token_id
        if repetition_count >= 10:
            yield "<END_STREAM>"
            return
        generated.append(token_id)
        yield decode_fn(token_id)
        if end_token_condition(token_id):
            yield "<END_STREAM>"
            return
        # The cache already holds the previous positions, so only the new token
        # is fed next. Without a cache, extend the full input instead.
        if past_key_values is not None:
            input_ids = next_token
        else:
            input_ids = torch.cat([input_ids, next_token], dim=-1)
def sample_sequence(prompt, model, enc, max_length=MAX_GENERATION_LENGTH, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
    """Start streaming a continuation of ``prompt`` from a GPT-2 style model.

    Args:
        prompt: Text to continue.
        model: Called as ``model(input_ids, past_key_values=past)``.
        enc: Encoder exposing ``encode``, ``decode`` and an ``encoder`` mapping
            from token string to id.
        max_length, temperature, top_k, top_p, repetition_penalty: Sampling
            controls forwarded to ``_generate_sequence``.
        device: Device on which the prompt tensor is created.

    Returns:
        The token-text generator produced by ``_generate_sequence``.
        Generation stops at ``END_OF_TEXT_TOKEN`` when that token exists in
        ``enc.encoder``; otherwise it runs until another stop condition.
    """
    context_tokens = enc.encode(prompt)
    context_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
    # Resolve the stop-token id once. END_OF_TEXT_TOKEN may be absent from the
    # vocabulary (its fallback value is ""). Indexing it directly would then
    # raise KeyError right after the first sampled token.
    end_token_id = enc.encoder.get(END_OF_TEXT_TOKEN)
    return _generate_sequence(
        lambda ct, past: model(ct, past_key_values=past),
        context_tensor,
        list(context_tokens),
        lambda token: enc.decode([token]),
        lambda token: end_token_id is not None and token == end_token_id,
        temperature, top_k, top_p, repetition_penalty, max_length,
    )
def sample_sequence_codegen(prompt, model, tokenizer, max_length=MAX_GENERATION_LENGTH, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu", eos_token_id=50256):
    """Start streaming a continuation of ``prompt`` from a code-generation model.

    Args:
        prompt: Text to continue.
        model: Called as ``model(input_ids=..., past_key_values=..., labels=None)``.
        tokenizer: Object exposing ``encode`` and ``decode``.
        max_length, temperature, top_k, top_p, repetition_penalty: Sampling
            controls forwarded to ``_generate_sequence``.
        device: Device on which the prompt tensor is created.
        eos_token_id: Token id that ends generation. The default 50256 is
            ``<|endoftext|>`` in the GPT-2 BPE vocabulary. Pass the
            tokenizer's own id if it uses a different one.

    Returns:
        The token-text generator produced by ``_generate_sequence``.
    """
    context_tokens = tokenizer.encode(prompt)
    context_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
    return _generate_sequence(
        lambda ct, past: model(input_ids=ct, past_key_values=past, labels=None),
        context_tensor,
        list(context_tokens),
        lambda token: tokenizer.decode([token]),
        lambda token: token == eos_token_id,
        temperature, top_k, top_p, repetition_penalty, max_length,
    )
def summarize_text(text):
    """Summarise ``text`` with the loaded seq2seq summarisation model.

    When both ``summarization_model`` and ``summarization_tokenizer`` are
    available, the input is truncated to 1024 tokens and summarised with
    4-beam search into 40-150 tokens, and the decoded summary is returned.
    Otherwise, text longer than 300 characters is cut to its first 300
    characters plus "...", and shorter text is returned unchanged.
    """
    if not (summarization_model and summarization_tokenizer):
        if len(text) <= 300:
            return text
        return text[:300] + "..."
    encoded = summarization_tokenizer.encode(
        text, return_tensors="pt", truncation=True, max_length=1024
    ).to(device)
    beam_settings = dict(
        max_length=150,
        min_length=40,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    summary_ids = summarization_model.generate(encoded, **beam_settings)
    best = summary_ids[0]
    return summarization_tokenizer.decode(best, skip_special_tokens=True)
def _web_search_snippets(query):
    # Return the non-empty "body" texts of up to MAX_XDD DuckDuckGo results ([] on failure).
    try:
        results = DDGS().text(query, max_results=MAX_XDD) or []
        return [r.get("body", "") for r in results if r.get("body")]
    except Exception:
        # Search only enriches the prompt; a network or rate-limit error
        # should not abort the whole response.
        return []


def _choose_model_type(text_input):
    # Pick "code", "summarize" or "gpt2" from request keywords and the models actually loaded; None if none fits.
    lowered = text_input.lower()
    if ("code" in lowered or "program" in lowered) and codegen_model is not None and codegen_tokenizer is not None:
        return "code"
    if ("summarize" in lowered or "summary" in lowered) and summarization_model is not None:
        return "summarize"
    if model_gpt2 is not None and enc is not None:
        return "gpt2"
    return None


def perform_reasoning_stream(text_input, temperature, top_k, top_p, repetition_penalty, prev_context="", max_chunks=16):
    """Stream a model answer to ``text_input`` as successive text chunks.

    The generator is chosen from keywords in ``text_input``:
    "code"/"program" selects the codegen model and "summarize"/"summary"
    selects the summariser. Each is used only when the model is loaded, and
    otherwise GPT-2 is used.

    For generation, the prompt is ``prev_context`` if given, else the system
    prompt plus the user turn, followed by any web-search snippets.
    Generation runs in rounds of up to ``MAX_GENERATION_LENGTH`` tokens. Each
    later round is re-prompted with only the last 256 characters produced so
    far.

    Args:
        text_input: The user request.
        temperature, top_k, top_p, repetition_penalty: Sampling controls.
        prev_context: Optional prompt that replaces the default one.
        max_chunks: Maximum number of generation rounds. Prevents an endless
            stream when the model never emits its end token.

    Yields:
        ``"SUMMARY_TEXT:<summary>"`` for summarisation requests, otherwise
        non-empty generated text chunks. If no usable model is loaded, yields
        an ``"<ERROR: ...>"`` message. The stream always ends with
        ``"<END_STREAM>"``.
    """
    model_type = _choose_model_type(text_input)
    if model_type is None:
        yield "<ERROR: No se encontró un modelo adecuado>"
        yield "<END_STREAM>"
        return
    if model_type == "summarize":
        yield "SUMMARY_TEXT:" + summarize_text(text_input)
        yield "<END_STREAM>"
        return

    reasoning_prompt = prev_context or (SYSTEM_PROMPT + "\n\nUser: " + text_input + "\nAssistant:")
    # Search only when the results will actually be used in a prompt.
    snippets = _web_search_snippets(text_input)
    if snippets:
        reasoning_prompt += "\nWeb Search Results:\n" + "".join("- " + s + "\n" for s in snippets) + "\n"

    overlap = 256
    accumulated_text = ""
    current_context = reasoning_prompt
    for _ in range(max_chunks):
        if model_type == "code":
            generator = sample_sequence_codegen(current_context, codegen_model, codegen_tokenizer, MAX_GENERATION_LENGTH, temperature, top_k, top_p, repetition_penalty, device)
        else:
            generator = sample_sequence(current_context, model_gpt2, enc, MAX_GENERATION_LENGTH, temperature, top_k, top_p, repetition_penalty, device)
        pieces = []
        finished = False
        for token in generator:
            if token == "<END_STREAM>":
                finished = True
                break
            pieces.append(token)
        chunk_text = "".join(pieces)
        # Drop a leading copy of the tail the model was re-prompted with.
        if accumulated_text:
            tail = accumulated_text[-overlap:]
            if chunk_text.startswith(tail):
                chunk_text = chunk_text[len(tail):]
        if chunk_text:
            accumulated_text += chunk_text
            yield chunk_text
        # Stop on an explicit end, or when a round made no progress.
        if finished or not chunk_text:
            break
        current_context = accumulated_text[-overlap:]
    yield "<END_STREAM>"