# Spaces: Running, Running  (page-header residue captured with the source; not part of the module)
import os | |
import tempfile | |
import uuid | |
import asyncio | |
import shutil | |
import requests | |
from urllib.parse import urlparse | |
from fastapi import FastAPI, UploadFile, File, HTTPException, Form, WebSocket | |
from fastapi.responses import JSONResponse | |
from fastapi import APIRouter | |
from extensions import * | |
from main import * | |
from tts_api import * | |
from sadtalker_utils import * | |
import base64 | |
from stt_api import * | |
from text_generation import * | |
# Early router binding. NOTE(review): `router` is bound again to a fresh
# APIRouter just before the routes are registered at the bottom of this file,
# so nothing is ever attached to this first instance.
router = APIRouter()
def _download_to_temp(url, suffix, temp_paths):
    """Stream *url* into a new temporary file and return the file's path.

    The path is appended to *temp_paths* as soon as the file exists, so the
    caller can delete it even when the download fails part-way through.
    Blocking: intended to be run in an executor, not on the event loop.
    """
    # A timeout keeps an unresponsive remote host from hanging the worker
    # thread forever; the context manager releases the streamed connection.
    with requests.get(url, stream=True, timeout=30) as response:
        response.raise_for_status()
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            temp_paths.append(tmp.name)
            for chunk in response.iter_content(chunk_size=8192):
                tmp.write(chunk)
    return tmp.name


async def _resolve_media_input(upload, location, default_suffix, temp_paths):
    """Turn an uploaded file or a URL/path string into a filesystem path.

    Returns ``None`` when neither *upload* nor *location* is provided. Every
    temporary file created here is recorded in *temp_paths*.
    """
    if upload:
        # `filename` may be None for some clients; fall back to no suffix.
        suffix = os.path.splitext(upload.filename or "")[1]
        content = await upload.read()
        # Close the file before returning so the bytes are flushed to disk
        # by the time another thread opens the path.
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            temp_paths.append(tmp.name)
            tmp.write(content)
        return tmp.name
    if location:
        if urlparse(location).scheme in ("http", "https"):
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                None, _download_to_temp, location, default_suffix, temp_paths
            )
        # NOTE(security): a non-URL value is used verbatim as a path on the
        # server, and URLs are fetched server-side. Both come straight from
        # the request; restrict them if this endpoint is exposed to
        # untrusted clients.
        return location
    return None


async def create_video(
    source_image: str = Form(None),
    source_image_file: UploadFile = File(None),
    driven_audio: str = Form(None),
    driven_audio_file: UploadFile = File(None),
    preprocess: str = Form('crop'),
    still_mode: bool = Form(False),
    use_enhancer: bool = Form(False),
    batch_size: int = Form(1),
    size: int = Form(256),
    pose_style: int = Form(0),
    exp_scale: float = Form(1.0),
    use_ref_video: bool = Form(False),
    ref_video: str = Form(None),
    ref_video_file: UploadFile = File(None),
    ref_info: str = Form(None),
    use_idle_mode: bool = Form(False),
    length_of_audio: int = Form(0),
    use_blink: bool = Form(True),
    checkpoint_dir: str = Form('checkpoints'),
    config_dir: str = Form('src/config'),
    old_version: bool = Form(False),
    tts_text: str = Form(None),
    tts_lang: str = Form('en'),
):
    """Render a video with ``sadtalker_instance.test`` from form inputs.

    Each media input (source image, driving audio, reference video) may be
    given either as an uploaded file or as a string that is an http(s) URL
    or a path; supplying both forms of the same input is rejected. Uploads
    and downloads are staged in temporary files that are removed before the
    handler returns. The remaining form fields are forwarded unchanged to
    ``sadtalker_instance.test``.

    ``checkpoint_dir``, ``config_dir`` and ``old_version`` are accepted for
    interface compatibility but are not used by this handler.

    Returns:
        ``{"video_url": <value returned by sadtalker_instance.test>}``.

    Raises:
        HTTPException: 400 when both forms of an input are given or no source
            image is provided; 500 when fetching an input or rendering fails.
    """
    if source_image_file and source_image:
        raise HTTPException(status_code=400, detail="source_image and source_image_file cannot be both not None")
    if driven_audio and driven_audio_file:
        raise HTTPException(status_code=400, detail="driven_audio and driven_audio_file cannot be both not None")
    if ref_video and ref_video_file:
        raise HTTPException(status_code=400, detail="ref_video and ref_video_file cannot be both not None")

    temp_paths = []
    try:
        source_image_path = await _resolve_media_input(
            source_image_file, source_image, '.png', temp_paths
        )
        if source_image_path is None:
            raise HTTPException(status_code=400, detail="source_image not provided")
        driven_audio_path = await _resolve_media_input(
            driven_audio_file, driven_audio, '.wav', temp_paths
        )
        ref_video_path = await _resolve_media_input(
            ref_video_file, ref_video, '.mp4', temp_paths
        )

        def _render():
            # run_in_executor() forwards positional arguments only, so the
            # keyword arguments are bound in this closure instead.
            return sadtalker_instance.test(
                source_image_path,
                driven_audio_path,
                preprocess,
                still_mode,
                use_enhancer,
                batch_size,
                size,
                pose_style,
                exp_scale,
                use_ref_video,
                ref_video_path,
                ref_info,
                use_idle_mode,
                length_of_audio,
                use_blink,
                './results/',
                tts_text=tts_text,
                tts_lang=tts_lang,
            )

        loop = asyncio.get_running_loop()
        output_path = await loop.run_in_executor(None, _render)
        return {"video_url": output_path}
    except HTTPException:
        # Deliberate client-facing errors keep their own status code.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Runs on every exit path, so staged files never outlive the request.
        for path in temp_paths:
            try:
                os.remove(path)
            except FileNotFoundError:
                # Already gone; nothing left to clean up.
                pass
def _generate_reply_from_audio(audio_bytes):
    """Transcribe *audio_bytes* and return the generated reply text.

    The bytes are staged in a temporary ``.wav`` file for
    ``speech_to_text_func``; that file is removed before returning.
    Blocking: intended to be run in an executor, not on the event loop.
    """
    tmp_path = None
    try:
        # Close the file before transcription so all bytes are on disk.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_audio_file:
            tmp_path = tmp_audio_file.name
            tmp_audio_file.write(audio_bytes)
        # speech_to_text_func returns the path of a text file that is read
        # back here. NOTE(review): that file is not removed by this code;
        # confirm who owns it.
        transcription_text_file = speech_to_text_func(tmp_path)
        with open(transcription_text_file, 'r') as f:
            transcription_text = f.read()
        # The numeric arguments are passed through as-is; presumably
        # sampling settings (TODO confirm against perform_reasoning_stream).
        pieces = []
        for chunk in perform_reasoning_stream(transcription_text, 0.7, 40, 0.0, 1.2):
            if chunk == "<END_STREAM>":
                break
            pieces.append(chunk)
        return "".join(pieces)
    finally:
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except FileNotFoundError:
                # Already gone; nothing left to clean up.
                pass


async def websocket_endpoint(websocket: WebSocket):
    """Serve a WebSocket that answers each message with a rendered video.

    Each incoming JSON message may carry ``"text"`` (synthesised directly
    with ``TTSTalker``) or ``"audio"`` (base64 audio that is transcribed,
    passed to ``perform_reasoning_stream``, and the reply synthesised).
    Messages with neither key are ignored. The resulting audio drives
    ``sadtalker_instance.test`` with fixed example inputs, and the result is
    sent back as ``{"video_url": ...}``.

    Failures while handling an ``"audio"`` message are reported as
    ``{"error": ...}`` and the loop continues; any other failure is reported
    the same way and ends the session. A client disconnect ends the session
    silently.
    """
    # Imported locally so this handler does not depend on the file-level
    # import list being extended.
    from fastapi import WebSocketDisconnect

    await websocket.accept()
    tts_model = TTSTalker()
    loop = asyncio.get_running_loop()
    # Fixed example inputs used for every rendered reply.
    source_image_path = './examples/source_image/cyarh.png'
    ref_video_path = './examples/driven_video/vid_xdd.mp4'
    try:
        while True:
            data = await websocket.receive_json()
            text = data.get("text")
            audio_base64 = data.get("audio")
            if text:
                audio_path = await loop.run_in_executor(None, tts_model.test, text)
            elif audio_base64:
                try:
                    audio_bytes = base64.b64decode(audio_base64)
                    # Transcription and text generation are blocking, so they
                    # run in the executor like the other model calls.
                    response_text = await loop.run_in_executor(
                        None, _generate_reply_from_audio, audio_bytes
                    )
                    audio_path = await loop.run_in_executor(None, tts_model.test, response_text)
                except Exception as e:
                    await websocket.send_json({"error": str(e)})
                    continue
            else:
                continue
            output = await loop.run_in_executor(
                None,
                sadtalker_instance.test,
                source_image_path,
                audio_path,
                'full',
                True,
                True,
                1,
                256,
                0,
                1,
                True,
                ref_video_path,
                "pose+blink",
                False,
                0,
                True,
                './results/',
            )
            await websocket.send_json({"video_url": output})
    except WebSocketDisconnect:
        # The peer is gone; sending an error frame would only raise again.
        return
    except Exception as e:
        print(e)
        await websocket.send_json({"error": str(e)})
# Router exposing this module's endpoints; handlers are attached explicitly
# rather than through decorators.
router = APIRouter()
# POST /sadtalker -> create_video
router.add_api_route(path="/sadtalker", endpoint=create_video, methods=["POST"])
# WebSocket /ws -> websocket_endpoint
router.add_api_websocket_route(path="/ws", endpoint=websocket_endpoint)