# api.py — Flask entry point for the conversational-avatar demo.
# (Hugging Face page-scrape artifacts removed from the top of this file;
# original page header: "Update api.py", commit 5b4adb6, 17.1 kB.)
from main import *
from tts_api import *
from stt_api import *
from sentiment_api import *
from imagegen_api import *
from musicgen_api import *
from translation_api import *
from codegen_api import *
from text_to_video_api import *
from summarization_api import *
from image_to_3d_api import *
from flask import Flask, request, jsonify, Response, send_file, stream_with_context
from flask_cors import CORS
import torch
import torch.nn.functional as F
import torchaudio
import numpy as np
from PIL import Image
import io
import tempfile
import queue
import json
import base64
from markupsafe import Markup
from markupsafe import escape
app = Flask(__name__)
CORS(app)
html_code = """<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>AI Conversational Avatar</title>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/animate.css/4.1.1/animate.min.css"/>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css"/>
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
<style>
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background: #f0f0f0;
color: #333;
margin: 0;
padding: 0;
display: flex;
flex-direction: column;
align-items: center;
min-height: 100vh;
}
.container {
width: 95%;
max-width: 900px;
padding: 20px;
background-color: #fff;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
border-radius: 8px;
margin-top: 20px;
margin-bottom: 20px;
display: flex;
flex-direction: column;
}
.header {
text-align: center;
margin-bottom: 20px;
}
.header h1 {
font-size: 2em;
color: #333;
}
.form-group {
margin-bottom: 15px;
}
.form-group textarea, .form-group input[type="text"] {
width: 100%;
padding: 10px;
border: 1px solid #ccc;
border-radius: 5px;
font-size: 16px;
box-sizing: border-box;
}
button, #recordButton, #stopButton {
padding: 10px 15px;
border: none;
border-radius: 5px;
background-color: #007bff;
color: white;
font-size: 18px;
cursor: pointer;
transition: background-color 0.3s ease;
margin-right: 5px;
}
button:hover, #recordButton:hover, #stopButton:hover {
background-color: #0056b3;
}
#output {
margin-top: 20px;
padding: 15px;
border: 1px solid #ddd;
border-radius: 5px;
background-color: #f9f9f9;
white-space: pre-wrap;
word-break: break-word;
overflow-y: auto;
max-height: 300px;
}
#videoOutput {
margin-top: 20px;
border: 1px solid #ddd;
border-radius: 5px;
overflow: hidden;
}
#videoOutput video {
width: 100%;
display: block;
}
#animatedText {
position: fixed;
top: 20px;
left: 20px;
font-size: 1.5em;
color: rgba(0, 0, 0, 0.1);
pointer-events: none;
z-index: -1;
}
#transcriptionOutput {
margin-top: 10px;
padding: 10px;
border: 1px solid #ddd;
border-radius: 5px;
background-color: #f9f9f9;
font-size: 14px;
word-break: break-word;
}
@media (max-width: 768px) {
.container {
width: 98%;
margin-top: 10px;
margin-bottom: 10px;
padding: 15px;
}
.header h1 {
font-size: 1.8em;
}
.form-group textarea, .form-group input[type="text"] {
font-size: 14px;
padding: 8px;
}
button, #recordButton, #stopButton {
font-size: 16px;
padding: 8px 12px;
}
#output, #transcriptionOutput {
font-size: 14px;
padding: 10px;
margin-top: 15px;
}
}
</style>
</head>
<body>
<div id="animatedText" class="animated-text animate__animated animate__fadeIn animate__infinite infinite">AI POWERED</div>
<div class="container">
<div class="header animate__animated animate__fadeInDown">
<h1>Conversational Avatar</h1>
</div>
<div class="form-group animate__animated animate__fadeInLeft">
<textarea id="textInput" rows="3" placeholder="Or type your request here"></textarea>
</div>
<div class="form-group animate__animated animate__fadeInRight" style="text-align: center;">
<button onclick="generateResponse()" class="animate__animated animate__fadeInUp">Generate Avatar Response</button>
</div>
<div style="text-align: center; margin-bottom: 15px;">
<button id="recordButton" class="animate__animated animate__fadeInUp"><i class="fas fa-microphone"></i> Start Recording</button>
<button id="stopButton" class="animate__animated animate__fadeInUp" disabled><i class="fas fa-stop-circle"></i> Stop Recording</button>
</div>
<div id="transcriptionOutput" class="animate__animated animate__fadeIn">
<strong>Transcription:</strong>
<span id="transcriptionText"></span>
</div>
<div id="output" class="animate__animated animate__fadeIn">
<strong>Response:</strong><br>
<span id="responseText"></span>
</div>
<div id="videoOutput" class="animate__animated animate__fadeIn">
<video id="avatarVideo" controls></video>
</div>
</div>
<script>
let mediaRecorder;
let audioChunks = [];
let lastResponse = "";
let accumulatedText = "";
let eventSource = null;
let audioURL;
const recordButton = document.getElementById('recordButton');
const stopButton = document.getElementById('stopButton');
const transcriptionTextSpan = document.getElementById('transcriptionText');
const responseTextSpan = document.getElementById('responseText');
const avatarVideoPlayer = document.getElementById('avatarVideo');
const textInputField = document.getElementById('textInput');
recordButton.onclick = async () => {
try {
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
mediaRecorder = new MediaRecorder(stream);
audioChunks = [];
mediaRecorder.ondataavailable = event => {
audioChunks.push(event.data);
};
mediaRecorder.onstop = async () => {
const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
const formData = new FormData();
formData.append('audio', audioBlob, 'recording.wav');
transcriptionTextSpan.innerText = "Transcribing...";
responseTextSpan.innerText = "";
avatarVideoPlayer.src = "";
try {
const sttResponse = await fetch('/api/v1/stt', {
method: 'POST',
body: formData
});
if (!sttResponse.ok) {
throw new Error(\`HTTP error! status: ${sttResponse.status}\`);
}
const sttData = await sttResponse.json();
const transcribedText = sttData.text;
transcriptionTextSpan.innerText = transcribedText || "Transcription failed.";
if (transcribedText) {
await generateAvatarVideoResponse(transcribedText);
}
} catch (error) {
console.error("STT or subsequent error:", error);
transcriptionTextSpan.innerText = "Transcription error.";
responseTextSpan.innerText = "Error processing audio.";
} finally {
recordButton.disabled = false;
stopButton.disabled = true;
}
};
recordButton.disabled = true;
stopButton.disabled = false;
transcriptionTextSpan.innerText = "Recording...";
mediaRecorder.start();
} catch (error) {
console.error("Error accessing microphone:", error);
transcriptionTextSpan.innerText = "Microphone access denied or error.";
recordButton.disabled = false;
stopButton.disabled = true;
}
};
stopButton.onclick = () => {
if (mediaRecorder && mediaRecorder.state === "recording") {
transcriptionTextSpan.innerText = "Processing...";
mediaRecorder.stop();
recordButton.disabled = true;
stopButton.disabled = true;
}
};
async function generateResponse() {
const inputText = textInputField.value;
if (!inputText.trim()) {
alert("Please enter text or record audio.");
return;
}
transcriptionTextSpan.innerText = inputText;
await generateAvatarVideoResponse(inputText);
}
async function generateAvatarVideoResponse(inputText) {
responseTextSpan.innerText = "Generating response...";
avatarVideoPlayer.src = "";
accumulatedText = "";
lastResponse = "";
const temp = 0.7;
const top_k_val = 40;
const top_p_val = 0.0;
const repetition_penalty_val = 1.2;
const requestData = {
text: inputText,
temp: temp,
top_k: top_k_val,
top_p: top_p_val,
reppenalty: repetition_penalty_val
};
if (eventSource) {
eventSource.close();
}
eventSource = new EventSource('/api/v1/generate_stream?' + new URLSearchParams(requestData).toString());
eventSource.onmessage = async function(event) {
if (event.data === "<END_STREAM>") {
eventSource.close();
const currentResponse = accumulatedText.replace("<|endoftext|>", "").replace(/\s+(?=[.,,。])/g, '').trim();
if (currentResponse === lastResponse.trim()) {
accumulatedText = "**Response is repetitive. Please try again or rephrase your query.**";
} else {
lastResponse = currentResponse;
}
responseTextSpan.innerHTML = marked.parse(accumulatedText);
try {
const ttsResponse = await fetch('/api/v1/tts', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({ text: currentResponse })
});
if (!ttsResponse.ok) {
throw new Error(\`TTS HTTP error! status: ${ttsResponse.status}\`);
}
const ttsBlob = await ttsResponse.blob();
audioURL = URL.createObjectURL(ttsBlob);
const sadTalkerResponse = await fetch('/api/v1/sadtalker', {
method: 'POST',
body: new URLSearchParams({
'source_image': './examples/source_image/full_body_female.png',
'driven_audio': audioURL,
'preprocess': 'full',
'still_mode': false,
'use_enhancer': true
})
});
if (!sadTalkerResponse.ok) {
throw new Error(\`SadTalker HTTP error! status: ${sadTalkerResponse.status}\`);
}
const sadTalkerData = await sadTalkerResponse.json();
const videoURL = sadTalkerData.video_url;
avatarVideoPlayer.src = videoURL;
} catch (ttsError) {
console.error("TTS or SadTalker error:", ttsError);
responseTextSpan.innerHTML += "<br><br>Error generating audio or video avatar.";
}
return;
}
accumulatedText += event.data;
let partialText = accumulatedText.replace("<|endoftext|>", "").replace(/\s+(?=[.,,。])/g, '').trim();
responseTextSpan.innerHTML = marked.parse(partialText);
};
eventSource.onerror = function(error) {
console.error("SSE error", error);
eventSource.close();
responseTextSpan.innerText = "Error generating response stream.";
};
const outputDiv = document.getElementById("output");
outputDiv.classList.add("show");
}
</script>
</body>
</html>
"""
feedback_queue = queue.Queue()
@app.route("/")
def index():
return html_code
@app.route("/api/v1/generate_stream", methods=["GET"])
def generate_stream():
text = request.args.get("text", "")
temp = float(request.args.get("temp", 0.7))
top_k = int(request.args.get("top_k", 40))
top_p = float(request.args.get("top_p", 0.0))
reppenalty = float(request.args.get("reppenalty", 1.2))
response_queue = queue.Queue()
reasoning_queue.put({
'text_input': text,
'temperature': temp,
'top_k': top_k,
'top_p': top_p,
'repetition_penalty': reppenalty,
'response_queue': response_queue
})
@stream_with_context
def event_stream():
while True:
output = response_queue.get()
if "error" in output:
yield "data: <ERROR>\n\n"
break
text_chunk = output.get("text")
if text_chunk:
for word in text_chunk.split(' '):
clean_word = word.strip()
if clean_word:
yield "data: " + clean_word + "\n\n"
yield "data: <END_STREAM>\n\n"
break
return Response(event_stream(), mimetype="text/event-stream")
@app.route("/api/v1/generate", methods=["POST"])
def generate():
data = request.get_json()
text = data.get("text", "")
temp = float(data.get("temp", 0.7))
top_k = int(data.get("top_k", 40))
top_p = float(data.get("top_p", 0.0))
reppenalty = float(data.get("reppenalty", 1.2))
response_queue = queue.Queue()
reasoning_queue.put({
'text_input': text,
'temperature': temp,
'top_k': top_k,
'top_p': top_p,
'repetition_penalty': reppenalty,
'response_queue': response_queue
})
output = response_queue.get()
if "error" in output:
return jsonify({"error": output["error"]}), 500
result_text = output.get("text", "").strip()
return jsonify({"response": result_text})
@app.route("/api/v1/feedback", methods=["POST"])
def feedback():
data = request.get_json()
feedback_text = data.get("feedback_text")
correct_category = data.get("correct_category")
if feedback_text and correct_category:
feedback_queue.put((feedback_text, correct_category))
return jsonify({"status": "feedback received"})
return jsonify({"status": "feedback failed"}), 400
@app.route("/api/v1/tts", methods=["POST"])
def tts_api():
return tts_route()
@app.route("/api/v1/stt", methods=["POST"])
def stt_api():
return stt_route()
@app.route("/api/v1/sentiment", methods=["POST"])
def sentiment_api():
return sentiment_route()
@app.route("/api/v1/imagegen", methods=["POST"])
def imagegen_api():
return imagegen_route()
@app.route("/api/v1/musicgen", methods=["POST"])
def musicgen_api():
return musicgen_route()
@app.route("/api/v1/translation", methods=["POST"])
def translation_api():
return translation_route()
@app.route("/api/v1/codegen", methods=["POST"])
def codegen_api():
return codegen_route()
@app.route("/api/v1/text_to_video", methods=["POST"])
def text_to_video_api():
return text_to_video_route()
@app.route("/api/v1/summarization", methods=["POST"])
def summarization_api():
return summarization_route()
@app.route("/api/v1/image_to_3d", methods=["POST"])
def image_to_3d_api():
return image_to_3d_route()
@app.route("/api/v1/sadtalker", methods=["POST"])
def sadtalker():
from sadtalker_api import router as sadtalker_router
return sadtalker_router.create_video()
if __name__ == "__main__":
app.run(host="0.0.0.0", port=7860)