# Spaces:
# Running
# Running
import time, threading, queue, uuid, unicodedata, re | |
from deep_translator import GoogleTranslator | |
from duckduckgo_search import DDGS | |
import nltk, torch, torch.nn as nn | |
nltk.download('punkt') | |
# Category labels: used to generate placeholder texts and as classifier targets.
categories = ['News', 'Sports', 'Entertainment']

# Divided by twice the number of categories to get how many texts are
# produced per category on each generation pass.
TEXT_GENERATION_RATE = 10

# Queues shared by the worker functions in this module.
text_queue = queue.Queue()       # receives (processed_text, category) tuples
reasoning_queue = queue.Queue()  # receives request dicts for the reasoning worker
feedback_queue = queue.Queue()   # receives (input_text, generated_text) pairs

# Known tokens and the reverse lookup from token to its position in the list.
vocabulary = ["<PAD>", "<EOS>"]
word_to_index = {token: position for position, token in enumerate(vocabulary)}

# Cleaned responses already returned by the reasoning worker.
seen_responses = set()

# Classifier instance; created by background_training() when still None.
news_clf = None
class SimpleClassifier(nn.Module):
    """Mean-of-embeddings classifier.

    Token indices are embedded, averaged along dimension 1 and fed through a
    single linear layer that produces ``num_classes`` outputs.
    """

    def __init__(self, vocab_size, num_classes, embedding_dim=128):
        """Create the embedding table and the output layer.

        Args:
            vocab_size: Number of rows in the embedding table.
            num_classes: Number of outputs of the linear layer.
            embedding_dim: Width of each embedding vector.
        """
        super().__init__()
        # Layers are created in the same order as before (embedding, then fc)
        # so that seeded parameter initialisation is unchanged.
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        self.fc = nn.Linear(in_features=embedding_dim, out_features=num_classes)

    def forward(self, x):
        """Return the linear layer's output for the mean-pooled embeddings of ``x``."""
        return self.fc(torch.mean(self.embedding(x), dim=1))
def tokenize_text(text):
    """Split ``text`` into tokens using NLTK's ``word_tokenize``."""
    tokens = nltk.word_tokenize(text)
    return tokens
def update_vocabulary(tokens):
    """Register every token not yet present in the shared vocabulary.

    A new token is mapped in ``word_to_index`` to the position at which it is
    appended to ``vocabulary``; tokens that are already known are left alone.
    """
    global vocabulary, word_to_index
    for candidate in tokens:
        if candidate in word_to_index:
            continue
        next_index = len(vocabulary)
        word_to_index[candidate] = next_index
        vocabulary.append(candidate)
def text_to_vector(text):
    """Encode ``text`` as a ``(1, number_of_tokens)`` tensor of vocabulary indices.

    The text is tokenised, unseen tokens are added to the shared vocabulary,
    and each token is replaced by its ``word_to_index`` entry (0 if missing).
    """
    pieces = tokenize_text(text)
    update_vocabulary(pieces)
    token_ids = torch.tensor(
        [word_to_index.get(piece, 0) for piece in pieces],
        dtype=torch.long,
    )
    # Add a leading dimension of size 1 in front of the token dimension.
    return token_ids.unsqueeze(0)
def generate_and_queue_text(language):
    """Endlessly generate placeholder texts per category and queue them.

    Each text has the form ``"Category: <category>. ID:<uuid4>"`` and is
    translated to ``language`` with ``GoogleTranslator``. When translation is
    unavailable or fails, the untranslated text is used instead. The result is
    NFKC-normalised, stripped of non-printable characters and put on
    ``text_queue`` as a ``(processed_text, category)`` tuple.

    Args:
        language: Target language handed to ``GoogleTranslator(target=...)``.

    This function loops forever and never returns.
    """
    global categories, text_queue

    # The translator does not depend on the text, so build it once. A failure
    # here (e.g. unsupported language) falls back to untranslated output.
    try:
        translator = GoogleTranslator(source='auto', target=language)
    except Exception:
        translator = None

    while True:
        # Snapshot the shared list so a concurrent change cannot disturb the
        # iteration, and tolerate categories being None or empty.
        active_categories = list(categories or ())
        if not active_categories:
            # Nothing to generate yet; wait instead of spinning or dividing by zero.
            time.sleep(1)
            continue

        # Never let the integer division reach 0, otherwise the outer loop
        # would spin at full speed without producing anything.
        per_category = max(1, TEXT_GENERATION_RATE // (2 * len(active_categories)))

        for category in active_categories:
            for _ in range(per_category):
                uid = uuid.uuid4()
                base_text = f"Category: {category}. ID:{uid}"
                text = base_text
                if translator is not None:
                    try:
                        text = translator.translate(base_text)
                    except Exception:
                        # Best effort: keep the untranslated text on any translation error.
                        text = base_text
                if not isinstance(text, str) or not text:
                    # unicodedata.normalize needs a str; fall back if the translator returned nothing usable.
                    text = base_text
                processed_text = ''.join(
                    c for c in unicodedata.normalize('NFKC', text) if c.isprintable()
                )
                text_queue.put((processed_text, category))
                # Zero-length sleep kept from the original: it only gives other threads a chance to run.
                time.sleep(0)
def _grow_embedding(model, vocab_size):
    """Enlarge ``model.embedding`` to ``vocab_size`` rows, keeping learned rows; return True if it was resized."""
    current = model.embedding
    if vocab_size <= current.num_embeddings:
        return False
    enlarged = nn.Embedding(vocab_size, current.embedding_dim)
    with torch.no_grad():
        # Copy the rows trained so far; only the added rows start from fresh initialisation.
        enlarged.weight[:current.num_embeddings] = current.weight
    model.embedding = enlarged
    return True


def background_training():
    """Continuously train ``news_clf`` on items taken from ``feedback_queue``.

    Each queue item is expected to unpack into ``(input_text, generated_text)``.
    The text is converted to vocabulary indices with ``text_to_vector`` and the
    target class is the position of ``generated_text`` in ``categories`` (0 when
    it is not a known category). One SGD step per epoch is taken with
    cross-entropy loss.

    ``news_clf`` is created on first use. When the vocabulary outgrows the
    model's embedding table the table is enlarged in place so earlier training
    is preserved. ``feedback_queue.task_done()`` is called exactly once for
    every item retrieved, including items that fail.

    This function loops forever and never returns.
    """
    global categories, news_clf, feedback_queue, vocabulary
    import logging  # local import so this block does not depend on the file's import list

    logger = logging.getLogger(__name__)

    if categories is None:
        categories = ['DefaultCategory']
    # Keep at least two output classes, decided once before the model is built
    # rather than by swapping the output layer afterwards.
    num_classes = max(len(categories), 2)
    learning_rate = 0.01
    epochs = 1

    if news_clf is None:
        news_clf = SimpleClassifier(max(len(vocabulary), 1), num_classes)
    optimizer = torch.optim.SGD(news_clf.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    while True:
        try:
            feedback_item = feedback_queue.get(timeout=10)
        except queue.Empty:
            continue

        try:
            if not feedback_item:
                continue
            input_text, generated_text = feedback_item

            # Vectorise first: this may add tokens to the shared vocabulary.
            input_tensor = text_to_vector(input_text)
            if input_tensor.size(1) == 0:
                # Averaging zero embeddings yields NaN, which would corrupt the weights.
                continue

            # Every index in the batch must have a row in the embedding table.
            required_rows = max(len(vocabulary), int(input_tensor.max()) + 1)
            if _grow_embedding(news_clf, required_rows):
                # The parameter set changed, so the optimizer must track the new tensor.
                optimizer = torch.optim.SGD(news_clf.parameters(), lr=learning_rate)

            target_index = categories.index(generated_text) if generated_text in categories else 0
            target = torch.tensor([target_index], dtype=torch.long)

            for _ in range(epochs):
                optimizer.zero_grad()
                loss = criterion(news_clf(input_tensor), target)
                loss.backward()
                optimizer.step()
        except Exception:
            # Keep the worker alive, but record the failure and back off briefly.
            logger.exception("Training on a feedback item failed")
            time.sleep(5)
        finally:
            # Runs on success, on `continue` and on failure, so the queue's
            # unfinished-task count always matches the number of items taken.
            feedback_queue.task_done()
def perform_reasoning_stream(text_input, temperature=0.7, top_k=40, top_p=0.0, repetition_penalty=1.2):
    """Yield generated tokens for ``text_input``, each followed by a space.

    Tokens are pulled from ``sample_sequence``. When it produces
    ``"<END_STREAM>"`` that marker is yielded unchanged and the generator ends.
    """
    sampler = sample_sequence(
        text_input,
        model_gpt2,
        enc,
        length=999999999,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        device=device,
    )
    for piece in sampler:
        if piece == "<END_STREAM>":
            yield "<END_STREAM>"
            return
        yield piece + " "
def background_reasoning_queue():
    """Serve text-generation requests taken from ``reasoning_queue`` forever.

    Each request is a dict read with these keys: ``text_input``,
    ``temperature``, ``top_k``, ``top_p``, ``repetition_penalty`` and
    ``response_queue``. The generated chunks are concatenated up to the
    ``"<END_STREAM>"`` marker, cleaned, and put on the request's response queue
    as ``{"text": ...}``. A cleaned response that was already returned before is
    replaced by a fixed "repetitive" notice. Failures are reported on the same
    response queue as ``{"error": ...}``.

    ``None`` items are acknowledged and ignored. ``reasoning_queue.task_done()``
    is called exactly once per retrieved item.

    This function loops forever and never returns.
    """
    global reasoning_queue, seen_responses
    while True:
        try:
            item = reasoning_queue.get(timeout=1)
        except queue.Empty:
            continue

        # Reset per request so an error is never reported to the queue of a previous request.
        resp_queue = None
        try:
            if item is None:
                continue

            resp_queue = item.get('response_queue', queue.Queue())
            text_input = item.get('text_input')
            temperature = item.get('temperature', 0.7)
            top_k = item.get('top_k', 40)
            top_p = item.get('top_p', 0.0)
            repetition_penalty = item.get('repetition_penalty', 1.2)

            if not text_input:
                resp_queue.put({"error": "Empty text input received."})
                continue

            generated_text_stream = perform_reasoning_stream(
                text_input,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )
            chunks = []
            for chunk in generated_text_stream:
                if chunk == "<END_STREAM>":
                    break
                chunks.append(chunk)
            full_response = "".join(chunks)

            # Drop the end-of-text marker, then remove whitespace in front of punctuation.
            # NOTE(review): the character class lists ',' twice; possibly a
            # full-width comma was intended. Confirm against the upstream file.
            cleaned_response = re.sub(
                r'\s+(?=[.,,。])', '', full_response.replace("<|endoftext|>", "").strip()
            )

            if cleaned_response in seen_responses:
                final_response = "**Response is repetitive. Please try again or rephrase your query.**"
            else:
                seen_responses.add(cleaned_response)
                final_response = cleaned_response
            resp_queue.put({"text": final_response})
        except Exception as e:
            if resp_queue is not None:
                try:
                    resp_queue.put({"error": str(e)})
                except Exception:
                    # Best effort: the caller-supplied object may not be a usable queue.
                    pass
        finally:
            # One task_done() per successful get(), whatever happened above,
            # so reasoning_queue.join() cannot hang or over-count.
            reasoning_queue.task_done()