import threading, queue, time, os, nltk, re, json
import torch  # imported explicitly: torch.device() is used below for device selection
from flask import Flask
from flask_cors import CORS
from api import *
from extensions import *
from constants import *
from configs import *
from tokenxxx import *
from models import *
from model_loader import *
from utils import *
from background_tasks import *
from text_generation import *
from sadtalker_utils import *

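# Shared module-level state: model handles, tokenizers/vocabularies, work queues,
# and bookkeeping used by the API layer and the background workers. Model handles
# stay None until load_models() fills them in.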
state_dict, enc, config, model_gpt2 = None, None, None, None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
news_clf, tfidf_vectorizer, categories = None, None, None
text_queue, feedback_queue, reasoning_queue = queue.Queue(), queue.Queue(), queue.Queue()
background_threads, seen_responses, dialogue_history = [], set(), []
vocabulary, word_to_index, index_to_word = set(), {}, []
translation_model, sp = None, None
codegen_model, codegen_tokenizer, codegen_vocabulary, codegen_index_to_word, codegen_word_to_index = None, None, None, None, None
summarization_model, summarization_vocabulary, summarization_word_to_index, summarization_index_to_word = None, set(), {}, []
sadtalker_instance, imagegen_model, image_to_3d_model, text_to_video_model = None, None, None, None
stream_type = "text"
sentiment_model, stt_model, tts_model, musicgen_model, xtts_model = None, None, None, None, None

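# Download (if needed) and initialize every model backend, storing the resulting
# handles in the module-level globals declared above.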
def load_models():
    global model_gpt2, enc, translation_model, codegen_model, codegen_tokenizer, summarization_model, imagegen_model, image_to_3d_model, text_to_video_model, sadtalker_instance, sentiment_model, stt_model, tts_model, musicgen_model, xtts_model
    model_gpt2, enc = initialize_gpt2_model(GPT2_FOLDER, {MODEL_FILE: MODEL_URL, ENCODER_FILE: ENCODER_URL, VOCAB_FILE: VOCAB_URL, CONFIG_FILE: GPT2CONFHG})
    translation_model = initialize_translation_model(TRANSLATION_FOLDER, TRANSLATION_MODEL_FILES_URLS)
    codegen_model, codegen_tokenizer, _, _, _ = initialize_codegen_model(CODEGEN_FOLDER, CODEGEN_FILES_URLS)
    summarization_model, _, _, _ = initialize_summarization_model(SUMMARIZATION_FOLDER, SUMMARIZATION_FILES_URLS)
    imagegen_model = initialize_imagegen_model(IMAGEGEN_FOLDER, IMAGEGEN_FILES_URLS)
    image_to_3d_model = initialize_image_to_3d_model(IMAGE_TO_3D_FOLDER, IMAGE_TO_3D_FILES_URLS)
    text_to_video_model = initialize_text_to_video_model(TEXT_TO_VIDEO_FOLDER, TEXT_TO_VIDEO_FILES_URLS)
    sentiment_model = initialize_sentiment_model(SENTIMENT_FOLDER, SENTIMENT_FILES_URLS)
    stt_model = initialize_stt_model(STT_FOLDER, STT_FILES_URLS)
    tts_model = initialize_tts_model(TTS_FOLDER, TTS_FILES_URLS)
    musicgen_model = initialize_musicgen_model(MUSICGEN_FOLDER, MUSICGEN_FILES_URLS)
    xtts_model = initialize_xtts_model(XTTS_FOLDER, XTTS_FILES_URLS)
    sadtalker_instance = SadTalker(checkpoint_path='./checkpoints', config_path='./src/config')

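# Startup sequence: fetch the NLTK tokenizer data, load all models, share state
# with the background_tasks module, start the worker threads, then run the Flask
# `app` provided by the wildcard imports above.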
if __name__ == "__main__":
    nltk.download('punkt')
    load_models()
    categories = ['Category1', 'Category2', 'Category3', 'Category4', 'Category5']
    # Re-import background_tasks as a module so its globals can be pointed at the
    # shared state created here, then start the worker threads.
    import background_tasks
    background_tasks.categories = categories
    background_tasks.text_queue = text_queue
    background_tasks.reasoning_queue = reasoning_queue
    background_threads.append(threading.Thread(target=generate_and_queue_text, args=('en',), daemon=True))
    background_threads.append(threading.Thread(target=generate_and_queue_text, args=('es',), daemon=True))
    background_threads.append(threading.Thread(target=background_training, daemon=True))
    background_threads.append(threading.Thread(target=background_reasoning_queue, daemon=True))
    for thread in background_threads:
        thread.start()
    app.run(host='0.0.0.0', port=7860)