Hhhh / background_tasks.py
Hjgugugjhuhjggg's picture
Upload 28 files
e83e49f verified
raw
history blame
7.05 kB
import time
import threading
import queue
import uuid
import unicodedata
import re
from deep_translator import GoogleTranslator
from duckduckgo_search import DDGS
import nltk
import torch
import torch.nn as nn
import math
# Tokenizer data for nltk.word_tokenize. Older NLTK releases use the pickled
# 'punkt' models; recent releases (3.9+) load the 'punkt_tab' resource instead
# and raise LookupError if only 'punkt' is present. nltk.download does not
# raise by default, so requesting a resource an older NLTK does not know is
# harmless.
nltk.download('punkt')
nltk.download('punkt_tab')
# Category labels: cycled through by generate_and_queue_text and used as the
# class names by background_training.
categories = ['News', 'Sports', 'Entertainment']
# generate_and_queue_text divides this by 2 * len(categories) to get the
# number of texts generated per category on each pass.
TEXT_GENERATION_RATE = 10
# (processed_text, category) tuples put by generate_and_queue_text.
text_queue = queue.Queue()
# Request dicts consumed by background_reasoning_queue.
reasoning_queue = queue.Queue()
# (input_text, generated_text) pairs consumed by background_training.
feedback_queue = queue.Queue()
# Growing token vocabulary; a token's position in this list is its index.
vocabulary = ["<PAD>", "<EOS>"]
# Reverse lookup token -> index, extended in step with `vocabulary` by
# update_vocabulary.
word_to_index = {word: idx for idx, word in enumerate(vocabulary)}
# Cleaned responses already returned by background_reasoning_queue; a repeat
# of any of them is rejected.
seen_responses = set()
# Text classifier, created lazily by background_training.
news_clf = None
class SimpleClassifier(nn.Module):
    """Bag-of-embeddings classifier.

    Token ids are embedded, averaged over dimension 1 and mapped to
    ``num_classes`` logits by a single linear layer.

    Args:
        vocab_size: Number of rows in the embedding table.
        num_classes: Number of output logits.
        embedding_dim: Width of each token embedding (and input width of the
            linear head).
    """

    def __init__(self, vocab_size, num_classes, embedding_dim=128):
        super().__init__()
        # Created in this order (embedding, then head) so parameter
        # initialisation consumes the RNG exactly as before.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.fc = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        """Return class logits for the token-id tensor ``x``.

        Dimension 1 of ``x`` is averaged away, so ``x`` is presumably
        ``(batch, seq_len)``; callers in this file pass a ``(1, n)`` tensor.
        """
        token_vectors = self.embedding(x)
        return self.fc(token_vectors.mean(dim=1))
def tokenize_text(text):
    """Return the list of word tokens NLTK's ``word_tokenize`` finds in ``text``."""
    word_tokens = nltk.word_tokenize(text)
    return word_tokens
def update_vocabulary(tokens):
    """Register every token of ``tokens`` not yet known in the global vocabulary.

    New tokens are appended to ``vocabulary`` in first-seen order and mapped
    to their list position in ``word_to_index``. Known tokens are left as is.
    Both globals are only mutated (never rebound), so no ``global``
    declaration is needed.
    """
    for tok in tokens:
        if tok in word_to_index:
            continue
        # Index equals the position the token is about to occupy.
        word_to_index[tok] = len(vocabulary)
        vocabulary.append(tok)
def text_to_vector(text):
    """Encode ``text`` as a ``(1, n_tokens)`` LongTensor of vocabulary indices.

    The text is tokenized with ``tokenize_text`` and any unseen tokens are
    added to the global vocabulary before lookup. Lookup falls back to index
    0 for a token that is somehow still missing.
    """
    words = tokenize_text(text)
    update_vocabulary(words)
    ids = [word_to_index.get(w, 0) for w in words]
    # Wrapping in an outer list adds the leading batch dimension of size 1.
    return torch.tensor([ids], dtype=torch.long)
def generate_and_queue_text(language):
    """Endlessly enqueue synthetic, optionally translated, per-category texts.

    Intended as a background-thread target; it never returns. Each pass over
    ``categories`` builds ``num_texts_per_category`` strings of the form
    ``"Category: <name>. ID:<uuid4>"`` per category. Each string is
    translated into ``language`` with ``GoogleTranslator`` when possible
    (the untranslated string is used otherwise), then NFKC-normalised with
    non-printable characters dropped. The result is put on ``text_queue``
    as a ``(text, category)`` tuple.

    Args:
        language: Target language code passed to ``GoogleTranslator``.
    """
    global categories, text_queue
    num_categories = len(categories or [])
    # TEXT_GENERATION_RATE // (2 * n) is 0 for n > 5, which would leave the
    # loop spinning without producing anything, and divides by zero for
    # n == 0. Always produce at least one text per category.
    num_texts_per_category = max(1, TEXT_GENERATION_RATE // (2 * max(1, num_categories)))
    # The translator does not depend on the loop variables: build it once.
    # If construction fails (e.g. unsupported language), texts are simply
    # left untranslated, as the per-item fallback did before.
    try:
        translator = GoogleTranslator(source='auto', target=language)
    except Exception:
        translator = None
    while True:
        if not categories:
            # Nothing to generate (empty or None); avoid a busy spin.
            time.sleep(1)
            continue
        for category in categories:
            for _ in range(num_texts_per_category):
                base_text = f"Category: {category}. ID:{uuid.uuid4()}"
                text = base_text
                if translator is not None:
                    try:
                        translated = translator.translate(base_text)
                    except Exception:
                        translated = None
                    # unicodedata.normalize below requires a str and sits
                    # outside any try, so a non-str result would kill this
                    # thread. Empty results are useless; keep the original.
                    if isinstance(translated, str) and translated.strip():
                        text = translated
                processed_text = ''.join(
                    c for c in unicodedata.normalize('NFKC', text) if c.isprintable()
                )
                # NOTE(review): text_queue is created without maxsize, so if
                # nothing drains it this producer grows memory without bound.
                text_queue.put((processed_text, category))
                time.sleep(0)  # yield to other threads between items
def _grow_embedding(model, vocab_size):
    """Enlarge ``model.embedding`` to ``vocab_size`` rows, keeping the trained rows.

    Returns True if the embedding module was replaced, in which case any
    optimizer built over the old parameters must be rebuilt.
    """
    old = model.embedding
    if old.num_embeddings >= vocab_size:
        return False
    new = nn.Embedding(vocab_size, old.embedding_dim).to(
        device=old.weight.device, dtype=old.weight.dtype
    )
    with torch.no_grad():
        new.weight[:old.num_embeddings].copy_(old.weight)
    model.embedding = new
    return True


def background_training():
    """Train ``news_clf`` online from ``feedback_queue`` items, forever.

    Intended as a background-thread target; it never returns. Each item is
    expected to be an ``(input_text, generated_text)`` pair (falsy items are
    acknowledged and skipped). The text is encoded with ``text_to_vector``,
    which also extends the global vocabulary. The label is
    ``categories.index(generated_text)`` when ``generated_text`` is one of
    ``categories`` and 0 otherwise. Training uses SGD (lr 0.01) on
    cross-entropy loss for ``epochs`` (1) steps per item.

    ``feedback_queue.task_done()`` is called exactly once per retrieved item,
    even if training on it fails. Failures are logged and followed by a
    5-second back-off.
    """
    global categories, news_clf
    import logging

    logger = logging.getLogger(__name__)
    if categories is None:
        categories = ['DefaultCategory']
    num_classes = len(categories)
    learning_rate = 0.01
    epochs = 1
    if news_clf is None:
        news_clf = SimpleClassifier(len(vocabulary), num_classes)
    if num_classes <= 1:
        # Cross-entropy over a single class is identically zero, so give the
        # head two outputs. This must happen before the optimizer is built;
        # otherwise the new head's parameters are never optimized.
        num_classes = 2
        news_clf.fc = nn.Linear(news_clf.fc.in_features, num_classes)
    optimizer = torch.optim.SGD(news_clf.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()
    while True:
        try:
            feedback_item = feedback_queue.get(timeout=10)
        except queue.Empty:
            continue
        try:
            if not feedback_item:
                continue
            input_text, generated_text = feedback_item
            # (1, n_tokens) index tensor; unseen tokens were just added to
            # `vocabulary`, so every index is < len(vocabulary).
            input_tensor = text_to_vector(input_text)
            if input_tensor.size(1) == 0:
                # Mean-pooling zero tokens yields NaN, which would poison the
                # weights via the SGD step.
                continue
            # Grow the embedding for new tokens rather than rebuilding the
            # whole model. A size check on the input's batch dimension (always
            # 1) re-initialised the classifier on every item, discarding all
            # prior training.
            if _grow_embedding(news_clf, len(vocabulary)):
                optimizer = torch.optim.SGD(news_clf.parameters(), lr=learning_rate)
            # NOTE(review): any generated_text that is not a category name is
            # trained as class 0 — confirm this is the intended labelling.
            target_index = categories.index(generated_text) if generated_text in categories else 0
            target = torch.tensor([target_index], dtype=torch.long)
            for _ in range(epochs):
                optimizer.zero_grad()
                loss = criterion(news_clf(input_tensor), target)
                loss.backward()
                optimizer.step()
        except Exception:
            logger.exception("background_training: skipping feedback item after error")
            time.sleep(5)
        finally:
            # One task_done() per successful get(), so feedback_queue.join()
            # cannot hang on a failed item.
            feedback_queue.task_done()
def perform_reasoning_stream(text_input, temperature=0.7, top_k=40, top_p=0.0, repetition_penalty=1.2):
    """Lazily stream generated tokens for ``text_input``.

    Every token from ``sample_sequence`` is yielded with a trailing space,
    until the ``"<END_STREAM>"`` marker arrives. The marker is yielded
    unchanged and the stream then stops.

    NOTE(review): ``sample_sequence``, ``model_gpt2``, ``enc`` and ``device``
    are not defined or imported anywhere in this module as visible; unless
    they are injected elsewhere, iterating this generator raises NameError.
    The ``length`` cap is effectively unbounded, so termination relies on
    the end marker.
    """
    end_marker = "<END_STREAM>"
    stream = sample_sequence(
        text_input,
        model_gpt2,
        enc,
        length=999999999,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        device=device,
    )
    for piece in stream:
        if piece != end_marker:
            yield piece + " "
            continue
        yield end_marker
        return
def _clean_reasoning_output(raw_text):
    """Normalise raw generated text before it is returned to a caller.

    Removes ``<|endoftext|>`` markers, trims surrounding whitespace and deletes
    whitespace immediately preceding ``.``, ``,``, ``，`` or ``。``.
    """
    # The class previously read '[.,,。]' (ASCII comma twice); next to the
    # ideographic full stop '。' the intended character was presumably the
    # full-width comma '，'.
    return re.sub(r'\s+(?=[.,，。])', '', raw_text.replace("<|endoftext|>", "").strip())


def background_reasoning_queue():
    """Serve text-generation requests from ``reasoning_queue`` forever.

    Intended as a background-thread target; it never returns. Each item is
    expected to be a dict with ``text_input`` and optional ``temperature``,
    ``top_k``, ``top_p``, ``repetition_penalty`` and ``response_queue``
    (defaults 0.7, 40, 0.0, 1.2 and a throw-away Queue). Exactly one reply is
    put on ``response_queue``: ``{"text": ...}`` on success or
    ``{"error": ...}`` on failure. ``None`` items are acknowledged and
    skipped. A cleaned response already present in ``seen_responses`` is
    replaced by a "repetitive" notice.

    ``reasoning_queue.task_done()`` is called exactly once per retrieved item.
    Errors are reported only to the response queue of the request that
    caused them.
    """
    global reasoning_queue, seen_responses
    while True:
        try:
            item = reasoning_queue.get(timeout=1)
        except queue.Empty:
            continue
        # Reset per item. A failure before this request's queue was read
        # (e.g. a non-dict item) must not be delivered to the previous
        # request's response queue.
        resp_queue = None
        try:
            if item is None:
                continue
            resp_queue = item.get('response_queue', queue.Queue())
            text_input = item.get('text_input')
            if not text_input:
                resp_queue.put({"error": "Empty text input received."})
                continue
            stream = perform_reasoning_stream(
                text_input,
                temperature=item.get('temperature', 0.7),
                top_k=item.get('top_k', 40),
                top_p=item.get('top_p', 0.0),
                repetition_penalty=item.get('repetition_penalty', 1.2),
            )
            chunks = []
            for chunk in stream:
                if chunk == "<END_STREAM>":
                    break
                chunks.append(chunk)
            cleaned_response = _clean_reasoning_output("".join(chunks))
            if cleaned_response in seen_responses:
                final_response = "**Response is repetitive. Please try again or rephrase your query.**"
            else:
                # NOTE(review): seen_responses is never pruned here, so it
                # grows for as long as the process runs.
                seen_responses.add(cleaned_response)
                final_response = cleaned_response
            resp_queue.put({"text": final_response})
        except Exception as e:
            if resp_queue is not None:
                try:
                    resp_queue.put({"error": str(e)})
                except Exception:
                    pass  # best effort: the caller's queue itself is unusable
        finally:
            # One task_done() per successful get(). Conditioning this on the
            # queue being non-empty skipped it for most failures and made
            # reasoning_queue.join() deadlock.
            reasoning_queue.task_done()