# NOTE: "Spaces: Running" was status-banner residue from the hosting page
# (Hugging Face Spaces UI), not part of the program source.
# --- Module setup -------------------------------------------------------
# NOTE(review): the star imports are kept in their original order on
# purpose — a later star import shadows earlier names, so reordering them
# could change which symbols win.
from main import *
from tts_api import *
from stt_api import *
from sentiment_api import *
from imagegen_api import *
from musicgen_api import *
from translation_api import *
from codegen_api import *
from text_to_video_api import *
from summarization_api import *
from image_to_3d_api import *
from flask import Flask, request, jsonify, Response, send_file, stream_with_context
from flask_cors import CORS
import torch
import torch.nn.functional as F
import torchaudio
import numpy as np
from PIL import Image
import io
import tempfile
import queue
import json
import base64
from markupsafe import Markup
from markupsafe import escape

# Flask app with permissive CORS so the browser UI can call the JSON APIs
# from any origin.
app = Flask(__name__)
CORS(app)
# Single-page UI served by index(): records (or accepts typed) input, streams
# the model reply over SSE, then chains TTS -> SadTalker to render an avatar
# video.
# Fix: the JS template literals were written as \`...\`; a Python
# triple-quoted string passes the backslash through verbatim, so the served
# page contained `throw new Error(\`...\`)` — a JavaScript syntax error.
# The backslashes are removed (they are only needed when nesting a backtick
# inside a JS template literal, not inside a Python string).
# NOTE(review): the SadTalker request posts a client-side blob: URL as
# 'driven_audio'; a server cannot fetch a browser blob URL — confirm how the
# backend resolves this before relying on the avatar path.
html_code = """<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>AI Conversational Avatar</title>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/animate.css/4.1.1/animate.min.css"/>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css"/>
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
<style>
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background: #f0f0f0;
color: #333;
margin: 0;
padding: 0;
display: flex;
flex-direction: column;
align-items: center;
min-height: 100vh;
}
.container {
width: 95%;
max-width: 900px;
padding: 20px;
background-color: #fff;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
border-radius: 8px;
margin-top: 20px;
margin-bottom: 20px;
display: flex;
flex-direction: column;
}
.header {
text-align: center;
margin-bottom: 20px;
}
.header h1 {
font-size: 2em;
color: #333;
}
.form-group {
margin-bottom: 15px;
}
.form-group textarea, .form-group input[type="text"] {
width: 100%;
padding: 10px;
border: 1px solid #ccc;
border-radius: 5px;
font-size: 16px;
box-sizing: border-box;
}
button, #recordButton, #stopButton {
padding: 10px 15px;
border: none;
border-radius: 5px;
background-color: #007bff;
color: white;
font-size: 18px;
cursor: pointer;
transition: background-color 0.3s ease;
margin-right: 5px;
}
button:hover, #recordButton:hover, #stopButton:hover {
background-color: #0056b3;
}
#output {
margin-top: 20px;
padding: 15px;
border: 1px solid #ddd;
border-radius: 5px;
background-color: #f9f9f9;
white-space: pre-wrap;
word-break: break-word;
overflow-y: auto;
max-height: 300px;
}
#videoOutput {
margin-top: 20px;
border: 1px solid #ddd;
border-radius: 5px;
overflow: hidden;
}
#videoOutput video {
width: 100%;
display: block;
}
#animatedText {
position: fixed;
top: 20px;
left: 20px;
font-size: 1.5em;
color: rgba(0, 0, 0, 0.1);
pointer-events: none;
z-index: -1;
}
#transcriptionOutput {
margin-top: 10px;
padding: 10px;
border: 1px solid #ddd;
border-radius: 5px;
background-color: #f9f9f9;
font-size: 14px;
word-break: break-word;
}
@media (max-width: 768px) {
.container {
width: 98%;
margin-top: 10px;
margin-bottom: 10px;
padding: 15px;
}
.header h1 {
font-size: 1.8em;
}
.form-group textarea, .form-group input[type="text"] {
font-size: 14px;
padding: 8px;
}
button, #recordButton, #stopButton {
font-size: 16px;
padding: 8px 12px;
}
#output, #transcriptionOutput {
font-size: 14px;
padding: 10px;
margin-top: 15px;
}
}
</style>
</head>
<body>
<div id="animatedText" class="animated-text animate__animated animate__fadeIn animate__infinite infinite">AI POWERED</div>
<div class="container">
<div class="header animate__animated animate__fadeInDown">
<h1>Conversational Avatar</h1>
</div>
<div class="form-group animate__animated animate__fadeInLeft">
<textarea id="textInput" rows="3" placeholder="Or type your request here"></textarea>
</div>
<div class="form-group animate__animated animate__fadeInRight" style="text-align: center;">
<button onclick="generateResponse()" class="animate__animated animate__fadeInUp">Generate Avatar Response</button>
</div>
<div style="text-align: center; margin-bottom: 15px;">
<button id="recordButton" class="animate__animated animate__fadeInUp"><i class="fas fa-microphone"></i> Start Recording</button>
<button id="stopButton" class="animate__animated animate__fadeInUp" disabled><i class="fas fa-stop-circle"></i> Stop Recording</button>
</div>
<div id="transcriptionOutput" class="animate__animated animate__fadeIn">
<strong>Transcription:</strong>
<span id="transcriptionText"></span>
</div>
<div id="output" class="animate__animated animate__fadeIn">
<strong>Response:</strong><br>
<span id="responseText"></span>
</div>
<div id="videoOutput" class="animate__animated animate__fadeIn">
<video id="avatarVideo" controls></video>
</div>
</div>
<script>
let mediaRecorder;
let audioChunks = [];
let lastResponse = "";
let accumulatedText = "";
let eventSource = null;
let audioURL;
const recordButton = document.getElementById('recordButton');
const stopButton = document.getElementById('stopButton');
const transcriptionTextSpan = document.getElementById('transcriptionText');
const responseTextSpan = document.getElementById('responseText');
const avatarVideoPlayer = document.getElementById('avatarVideo');
const textInputField = document.getElementById('textInput');
recordButton.onclick = async () => {
try {
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
mediaRecorder = new MediaRecorder(stream);
audioChunks = [];
mediaRecorder.ondataavailable = event => {
audioChunks.push(event.data);
};
mediaRecorder.onstop = async () => {
const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
const formData = new FormData();
formData.append('audio', audioBlob, 'recording.wav');
transcriptionTextSpan.innerText = "Transcribing...";
responseTextSpan.innerText = "";
avatarVideoPlayer.src = "";
try {
const sttResponse = await fetch('/api/v1/stt', {
method: 'POST',
body: formData
});
if (!sttResponse.ok) {
throw new Error(`HTTP error! status: ${sttResponse.status}`);
}
const sttData = await sttResponse.json();
const transcribedText = sttData.text;
transcriptionTextSpan.innerText = transcribedText || "Transcription failed.";
if (transcribedText) {
await generateAvatarVideoResponse(transcribedText);
}
} catch (error) {
console.error("STT or subsequent error:", error);
transcriptionTextSpan.innerText = "Transcription error.";
responseTextSpan.innerText = "Error processing audio.";
} finally {
recordButton.disabled = false;
stopButton.disabled = true;
}
};
recordButton.disabled = true;
stopButton.disabled = false;
transcriptionTextSpan.innerText = "Recording...";
mediaRecorder.start();
} catch (error) {
console.error("Error accessing microphone:", error);
transcriptionTextSpan.innerText = "Microphone access denied or error.";
recordButton.disabled = false;
stopButton.disabled = true;
}
};
stopButton.onclick = () => {
if (mediaRecorder && mediaRecorder.state === "recording") {
transcriptionTextSpan.innerText = "Processing...";
mediaRecorder.stop();
recordButton.disabled = true;
stopButton.disabled = true;
}
};
async function generateResponse() {
const inputText = textInputField.value;
if (!inputText.trim()) {
alert("Please enter text or record audio.");
return;
}
transcriptionTextSpan.innerText = inputText;
await generateAvatarVideoResponse(inputText);
}
async function generateAvatarVideoResponse(inputText) {
responseTextSpan.innerText = "Generating response...";
avatarVideoPlayer.src = "";
accumulatedText = "";
lastResponse = "";
const temp = 0.7;
const top_k_val = 40;
const top_p_val = 0.0;
const repetition_penalty_val = 1.2;
const requestData = {
text: inputText,
temp: temp,
top_k: top_k_val,
top_p: top_p_val,
reppenalty: repetition_penalty_val
};
if (eventSource) {
eventSource.close();
}
eventSource = new EventSource('/api/v1/generate_stream?' + new URLSearchParams(requestData).toString());
eventSource.onmessage = async function(event) {
if (event.data === "<END_STREAM>") {
eventSource.close();
const currentResponse = accumulatedText.replace("<|endoftext|>", "").replace(/\s+(?=[.,,。])/g, '').trim();
if (currentResponse === lastResponse.trim()) {
accumulatedText = "**Response is repetitive. Please try again or rephrase your query.**";
} else {
lastResponse = currentResponse;
}
responseTextSpan.innerHTML = marked.parse(accumulatedText);
try {
const ttsResponse = await fetch('/api/v1/tts', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({ text: currentResponse })
});
if (!ttsResponse.ok) {
throw new Error(`TTS HTTP error! status: ${ttsResponse.status}`);
}
const ttsBlob = await ttsResponse.blob();
audioURL = URL.createObjectURL(ttsBlob);
const sadTalkerResponse = await fetch('/api/v1/sadtalker', {
method: 'POST',
body: new URLSearchParams({
'source_image': './examples/source_image/full_body_female.png',
'driven_audio': audioURL,
'preprocess': 'full',
'still_mode': false,
'use_enhancer': true
})
});
if (!sadTalkerResponse.ok) {
throw new Error(`SadTalker HTTP error! status: ${sadTalkerResponse.status}`);
}
const sadTalkerData = await sadTalkerResponse.json();
const videoURL = sadTalkerData.video_url;
avatarVideoPlayer.src = videoURL;
} catch (ttsError) {
console.error("TTS or SadTalker error:", ttsError);
responseTextSpan.innerHTML += "<br><br>Error generating audio or video avatar.";
}
return;
}
accumulatedText += event.data;
let partialText = accumulatedText.replace("<|endoftext|>", "").replace(/\s+(?=[.,,。])/g, '').trim();
responseTextSpan.innerHTML = marked.parse(partialText);
};
eventSource.onerror = function(error) {
console.error("SSE error", error);
eventSource.close();
responseTextSpan.innerText = "Error generating response stream.";
};
const outputDiv = document.getElementById("output");
outputDiv.classList.add("show");
}
</script>
</body>
</html>
"""
# Cross-request buffer for user-submitted corrections; filled by feedback()
# and presumably drained by a background worker not visible in this file.
feedback_queue = queue.Queue()
def index():
    """Serve the embedded single-page avatar UI.

    NOTE(review): no @app.route decorator is visible on this (or any)
    handler in this file — routes are presumably registered elsewhere or
    were lost in extraction; confirm before deploying.
    """
    page = html_code
    return page
def generate_stream():
    """Stream a generated reply as Server-Sent Events, one word per event.

    Query args: ``text``, ``temp``, ``top_k``, ``top_p``, ``reppenalty``.
    The request is handed to the module-global ``reasoning_queue``
    (star-imported, presumably from ``main`` — confirm); the worker answers
    on a private queue.  Wire protocol expected by the embedded JS client:
    ``data: <word>`` events, ``<ERROR>`` on failure, ``<END_STREAM>`` last.
    """
    args = request.args
    reply_queue = queue.Queue()
    reasoning_queue.put({
        'text_input': args.get("text", ""),
        'temperature': float(args.get("temp", 0.7)),
        'top_k': int(args.get("top_k", 40)),
        'top_p': float(args.get("top_p", 0.0)),
        'repetition_penalty': float(args.get("reppenalty", 1.2)),
        'response_queue': reply_queue,
    })

    def event_stream():
        # Only the first worker reply is consumed: both branches break.
        while True:
            reply = reply_queue.get()
            if "error" in reply:
                yield "data: <ERROR>\n\n"
                break
            chunk = reply.get("text")
            if chunk:
                for raw_word in chunk.split(' '):
                    word = raw_word.strip()
                    if word:
                        yield f"data: {word}\n\n"
            yield "data: <END_STREAM>\n\n"
            break

    return Response(event_stream(), mimetype="text/event-stream")
def generate():
    """Blocking JSON variant of generate_stream: same worker hand-off via
    ``reasoning_queue``, but the full reply is returned in one response.

    Body (JSON): ``text``, ``temp``, ``top_k``, ``top_p``, ``reppenalty``.
    Returns ``{"response": <text>}`` or ``{"error": ...}`` with HTTP 500.
    """
    payload = request.get_json()
    reply_queue = queue.Queue()
    reasoning_queue.put({
        'text_input': payload.get("text", ""),
        'temperature': float(payload.get("temp", 0.7)),
        'top_k': int(payload.get("top_k", 40)),
        'top_p': float(payload.get("top_p", 0.0)),
        'repetition_penalty': float(payload.get("reppenalty", 1.2)),
        'response_queue': reply_queue,
    })
    reply = reply_queue.get()
    if "error" in reply:
        return jsonify({"error": reply["error"]}), 500
    return jsonify({"response": reply.get("text", "").strip()})
def feedback():
    """Accept a user correction and enqueue it on ``feedback_queue``.

    Expects JSON with ``feedback_text`` and ``correct_category``; both must
    be present (and truthy), otherwise HTTP 400 is returned.
    """
    data = request.get_json()
    text = data.get("feedback_text")
    category = data.get("correct_category")
    # Guard clause: reject incomplete submissions up front.
    if not (text and category):
        return jsonify({"status": "feedback failed"}), 400
    feedback_queue.put((text, category))
    return jsonify({"status": "feedback received"})
def tts_api():
    """Thin adapter delegating to tts_route() (star-imported from tts_api)."""
    response = tts_route()
    return response
def stt_api():
    """Thin adapter delegating to stt_route() (star-imported from stt_api)."""
    response = stt_route()
    return response
def sentiment_api():
    """Thin adapter delegating to sentiment_route() (star-imported from sentiment_api)."""
    response = sentiment_route()
    return response
def imagegen_api():
    """Thin adapter delegating to imagegen_route() (star-imported from imagegen_api)."""
    response = imagegen_route()
    return response
def musicgen_api():
    """Thin adapter delegating to musicgen_route() (star-imported from musicgen_api)."""
    response = musicgen_route()
    return response
def translation_api():
    """Thin adapter delegating to translation_route() (star-imported from translation_api)."""
    response = translation_route()
    return response
def codegen_api():
    """Thin adapter delegating to codegen_route() (star-imported from codegen_api)."""
    response = codegen_route()
    return response
def text_to_video_api():
    """Thin adapter delegating to text_to_video_route() (star-imported from text_to_video_api)."""
    response = text_to_video_route()
    return response
def summarization_api():
    """Thin adapter delegating to summarization_route() (star-imported from summarization_api)."""
    response = summarization_route()
    return response
def image_to_3d_api():
    """Thin adapter delegating to image_to_3d_route() (star-imported from image_to_3d_api)."""
    response = image_to_3d_route()
    return response
def sadtalker():
    """Create a talking-avatar video via the SadTalker router.

    The import is deferred to call time so the (heavy) SadTalker stack is
    only loaded when this endpoint is actually hit.
    """
    from sadtalker_api import router
    return router.create_video()
if __name__ == "__main__":
    # Bind on all interfaces so the server is reachable from outside a
    # container; 7860 is the conventional Hugging Face Spaces port.
    app.run(host="0.0.0.0", port=7860)