# Hhhh / background_tasks.py
# Uploaded by Kfjjdjdjdhdhd in commit 7b74407 ("Upload 26 files", verified); file size 6.05 kB.
import time, threading, queue, uuid, unicodedata, re
from deep_translator import GoogleTranslator
from duckduckgo_search import DDGS
import nltk, torch, torch.nn as nn
# nltk.word_tokenize needs the Punkt sentence-tokenizer data. Older NLTK
# releases load it from the "punkt" resource; newer releases load it from
# "punkt_tab" instead, so fetch both to work across versions.
nltk.download('punkt')
nltk.download('punkt_tab')
# Topic labels; a label's position in this list is used as its class index.
categories = ['News', 'Sports', 'Entertainment']
# Generation budget read by generate_and_queue_text.
TEXT_GENERATION_RATE = 10

# Queues shared between the background worker functions in this module.
text_queue, reasoning_queue, feedback_queue = queue.Queue(), queue.Queue(), queue.Queue()

# Token vocabulary: `vocabulary` maps index -> token, `word_to_index` is its
# inverse (token -> index). Both start with the two special tokens.
vocabulary = ["<PAD>", "<EOS>"]
word_to_index = dict(zip(vocabulary, range(len(vocabulary))))

# Responses already returned by background_reasoning_queue (repeat detection).
seen_responses = set()
# Classifier instance; created on first use by background_training.
news_clf = None
class SimpleClassifier(nn.Module):
    """Mean-pooled bag-of-embeddings classifier.

    Token ids are looked up in an embedding table, averaged over dim 1, and
    projected to one logit per class by a single linear layer.
    """

    def __init__(self, vocab_size, num_classes, embedding_dim=128):
        super().__init__()
        # Submodules are created in this order (embedding, then fc) so that
        # parameter registration order and RNG consumption are fixed.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.fc = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        """Return class logits for a batch of token-id sequences.

        ``x`` must be at least 2-D (e.g. ``(batch, seq)``): the mean is taken
        over dim 1, and the linear layer needs the embedding axis to remain
        last.
        """
        return self.fc(self.embedding(x).mean(dim=1))
def tokenize_text(text):
    """Split *text* into word tokens using NLTK's ``word_tokenize``.

    Requires NLTK's Punkt tokenizer data to be installed (this module calls
    ``nltk.download`` for it at import time).
    """
    word_tokens = nltk.word_tokenize(text)
    return word_tokens
def update_vocabulary(tokens):
    """Add every token not yet known to the module-level vocabulary.

    Mutates ``vocabulary`` (index -> token) and ``word_to_index``
    (token -> index) in place. Each new token gets the next consecutive id,
    i.e. the current vocabulary length. Tokens that are already known,
    including repeats within *tokens*, are left unchanged.
    """
    # No `global` needed: both containers are mutated, never rebound.
    for token in tokens:
        if token not in word_to_index:
            word_to_index[token] = len(vocabulary)
            vocabulary.append(token)
def text_to_vector(text):
    """Encode *text* as a ``(1, num_tokens)`` long tensor of vocabulary ids.

    Tokenizes with ``tokenize_text`` and registers unseen tokens via
    ``update_vocabulary``, so the global vocabulary may grow as a side effect.
    Each token is then mapped through ``word_to_index``, using id 0 for any
    token that is missing.
    """
    tokens = tokenize_text(text)
    update_vocabulary(tokens)
    ids = [word_to_index.get(tok, 0) for tok in tokens]
    # Wrapping in an outer list adds the leading batch axis of size 1.
    return torch.tensor([ids], dtype=torch.long)
def generate_and_queue_text(language):
    """Endlessly produce per-category placeholder texts onto ``text_queue``.

    Each text is ``"Category: <name>. ID:<uuid4>"``. The text is translated
    into *language* with ``GoogleTranslator``; on any translation error, or an
    empty/``None`` result, the untranslated string is used instead. The text
    is then NFKC-normalized, stripped of non-printable characters and queued
    as a ``(text, category)`` tuple. Never returns.

    Raises:
        ValueError: if ``categories`` is empty or ``None`` (there would be
            nothing to generate).
    """
    if not categories:
        raise ValueError("categories must not be empty")
    num_categories = len(categories)
    # Clamp to at least 1: with many categories the integer division yields 0,
    # which would turn the outer loop into a CPU-burning spin with no sleep.
    num_texts_per_category = max(1, TEXT_GENERATION_RATE // (2 * num_categories))
    # The translator depends only on `language`, so build it once. If
    # construction fails, every text falls back to its untranslated form.
    try:
        translator = GoogleTranslator(source='auto', target=language)
    except Exception:
        translator = None
    while True:
        for category in categories:
            for _ in range(num_texts_per_category):
                base_text = f"Category: {category}. ID:{uuid.uuid4()}"
                text = base_text
                if translator is not None:
                    try:
                        text = translator.translate(base_text) or base_text
                    except Exception:
                        text = base_text
                processed_text = ''.join(
                    c for c in unicodedata.normalize('NFKC', text) if c.isprintable()
                )
                text_queue.put((processed_text, category))
                # Give other threads a chance to run between items.
                time.sleep(0)
def _grow_embedding(model, vocab_size):
    """Enlarge ``model.embedding`` to *vocab_size* rows, keeping learned rows.

    Returns True if the embedding was replaced (any optimizer holding the old
    parameter must then be rebuilt), or False if it was already large enough.
    """
    old = model.embedding
    if vocab_size <= old.num_embeddings:
        return False
    new = nn.Embedding(
        vocab_size,
        old.embedding_dim,
        device=old.weight.device,
        dtype=old.weight.dtype,
    )
    with torch.no_grad():
        new.weight[: old.num_embeddings] = old.weight
    model.embedding = new
    return True


def background_training():
    """Train ``news_clf`` online from items on ``feedback_queue``; never returns.

    Each non-empty item is unpacked as ``(input_text, generated_text)``.
    ``input_text`` is encoded with ``text_to_vector``, which may grow the
    vocabulary. ``generated_text`` is the target label: its index in
    ``categories``, or class 0 if it is not a known category. Each item gets
    ``epochs`` SGD steps with cross-entropy loss.

    The classifier is created on first use. When the vocabulary later grows,
    its embedding is enlarged in place (existing rows kept) rather than
    rebuilding the model, so learning accumulates across items.
    ``task_done()`` is called exactly once per item retrieved.
    """
    global categories, news_clf
    import logging

    logger = logging.getLogger(__name__)
    if categories is None:
        categories = ['DefaultCategory']
    # CrossEntropyLoss over a single logit is degenerate; use >= 2 classes
    # from the start so the output layer is tracked by the optimizer.
    num_classes = max(len(categories), 2)
    learning_rate = 0.01
    epochs = 1
    if news_clf is None:
        news_clf = SimpleClassifier(len(vocabulary), num_classes)
    optimizer = torch.optim.SGD(news_clf.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()
    while True:
        try:
            feedback_item = feedback_queue.get(timeout=10)
        except queue.Empty:
            continue
        try:
            if not feedback_item:
                continue
            input_text, generated_text = feedback_item
            input_tensor = text_to_vector(input_text)
            if input_tensor.size(1) == 0:
                # Mean-pooling an empty sequence yields NaN logits, whose
                # gradients would poison every weight.
                continue
            # New tokens may have ids beyond the current embedding table.
            if _grow_embedding(news_clf, len(vocabulary)):
                optimizer = torch.optim.SGD(news_clf.parameters(), lr=learning_rate)
            target_index = categories.index(generated_text) if generated_text in categories else 0
            target = torch.tensor([target_index], dtype=torch.long)
            for _ in range(epochs):
                optimizer.zero_grad()
                loss = criterion(news_clf(input_tensor), target)
                loss.backward()
                optimizer.step()
        except Exception:
            logger.exception("background_training: failed to process feedback item")
            time.sleep(5)  # back off before taking the next item
        finally:
            feedback_queue.task_done()
def perform_reasoning_stream(text_input, temperature=0.7, top_k=40, top_p=0.0, repetition_penalty=1.2):
    """Generator yielding sampled tokens for *text_input*, each with a trailing space.

    Delegates to ``sample_sequence`` (with ``model_gpt2``, ``enc`` and
    ``device``), passing the sampling settings through with a length of
    999999999. When the ``"<END_STREAM>"`` sentinel arrives it is yielded
    as-is (no trailing space) and the generator stops.

    NOTE(review): ``sample_sequence``, ``model_gpt2``, ``enc`` and ``device``
    are not defined or imported anywhere in this file as shown; confirm where
    they come from, or iterating this generator raises NameError.
    """
    token_source = sample_sequence(
        text_input, model_gpt2, enc,
        length=999999999, temperature=temperature, top_k=top_k,
        top_p=top_p, repetition_penalty=repetition_penalty, device=device,
    )
    for piece in token_source:
        if piece == "<END_STREAM>":
            yield "<END_STREAM>"
            return
        yield piece + " "
def background_reasoning_queue():
    """Serve text-generation requests from ``reasoning_queue``; never returns.

    Items are read as dicts. NOTE(review): this schema is inferred only from
    the keys read below. Keys used: ``text_input``, plus optional
    ``temperature``, ``top_k``, ``top_p``, ``repetition_penalty`` and
    ``response_queue``.

    For each request, exactly one message is put on the item's response queue:
    ``{"text": ...}`` on success or ``{"error": ...}`` on failure. A response
    identical to one already produced is replaced with a "repetitive" notice.
    ``None`` items are skipped. ``task_done()`` is called exactly once per
    item retrieved, whatever the outcome, so ``reasoning_queue.join()`` cannot
    hang.

    NOTE(review): ``seen_responses`` is only ever added to, so it grows for
    the life of the process.
    """
    import logging

    logger = logging.getLogger(__name__)
    while True:
        try:
            item = reasoning_queue.get(timeout=1)
        except queue.Empty:
            continue
        # Reset per item so a failure is never reported to an earlier
        # request's response queue.
        resp_queue = None
        try:
            if item is None:
                continue
            resp_queue = item.get('response_queue', queue.Queue())
            text_input = item.get('text_input')
            if not text_input:
                resp_queue.put({"error": "Empty text input received."})
                continue
            stream = perform_reasoning_stream(
                text_input,
                temperature=item.get('temperature', 0.7),
                top_k=item.get('top_k', 40),
                top_p=item.get('top_p', 0.0),
                repetition_penalty=item.get('repetition_penalty', 1.2),
            )
            chunks = []
            for chunk in stream:
                if chunk == "<END_STREAM>":
                    break
                chunks.append(chunk)
            full_response = "".join(chunks)
            # Drop the end-of-text marker and any whitespace before '.', ',' or
            # '。'. NOTE(review): ',' appears twice in the character class; a
            # non-ASCII comma (e.g. '，') may have been intended — confirm.
            cleaned_response = re.sub(r'\s+(?=[.,,。])', '', full_response.replace("<|endoftext|>", "").strip())
            if cleaned_response in seen_responses:
                final_response = "**Response is repetitive. Please try again or rephrase your query.**"
            else:
                seen_responses.add(cleaned_response)
                final_response = cleaned_response
            resp_queue.put({"text": final_response})
        except Exception as e:
            logger.exception("background_reasoning_queue: request failed")
            if resp_queue is not None:
                try:
                    resp_queue.put({"error": str(e)})
                except Exception:
                    # Best effort: the caller supplied an unusable response queue.
                    logger.exception("background_reasoning_queue: could not report error")
        finally:
            reasoning_queue.task_done()