import gradio as gr
import torch
import os
import numpy as np
from PIL import Image
import cv2
import tempfile
import moviepy.editor as mp
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from diffusers.utils import export_to_video, load_image
# Import required modules from SkyReels
from skyreels_a1.models.transformer3d import CogVideoXTransformer3DModel
from skyreels_a1.skyreels_a1_i2v_pipeline import SkyReelsA1ImagePoseToVideoPipeline
from skyreels_a1.pre_process_lmk3d import FaceAnimationProcessor
from skyreels_a1.src.media_pipe.mp_utils import LMKExtractor
from skyreels_a1.src.media_pipe.draw_util_2d import FaceMeshVisualizer2d
from diffusers.models import AutoencoderKLCogVideoX
from transformers import SiglipImageProcessor, SiglipVisionModel
from diffposetalk.diffposetalk import DiffPoseTalk
from huggingface_hub import snapshot_download
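# Install pytorch3d at runtime from the official prebuilt wheel index
# (the Python 3.10 / CUDA 12.1 / PyTorch 2.2.1 build referenced in the URL below).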
os.system("pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt221/download.html")
os.makedirs("pretrained_models", exist_ok=True)
snapshot_download(
    repo_id="multimodalart/diffposetalk",
    local_dir="pretrained_models/diffposetalk"
)

snapshot_download(
    repo_id="Skywork/SkyReels-A1",
    local_dir="pretrained_models/FLAME",
    allow_patterns="extra_models/FLAME/**"
)

snapshot_download(
    repo_id="Skywork/SkyReels-A1",
    local_dir="pretrained_models/mediapipe",
    allow_patterns="extra_models/mediapipe/**"
)

snapshot_download(
    repo_id="Skywork/SkyReels-A1",
    local_dir="pretrained_models/smirk",
    allow_patterns="extra_models/smirk/**"
)

snapshot_download(
    repo_id="Skywork/SkyReels-A1",
    local_dir="pretrained_models/SkyReels-A1-5B",
    allow_patterns="SkyReels-A1-5B/**"
)
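# All weights above (DiffPoseTalk, FLAME, MediaPipe, SMIRK and the SkyReels-A1-5B pipeline)
# are fetched into pretrained_models/ before any model is instantiated below.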
# Helper functions from the original script
def parse_video(driving_frames, max_frame_num, fps=25):
    """Resample driving frames from `fps` down to 12 fps, then pad/duplicate
    frames so the result always contains max_frame_num - 1 frames."""
    video_length = len(driving_frames)
    duration = video_length / fps
    target_times = np.arange(0, duration, 1 / 12)
    frame_indices = (target_times * fps).astype(np.int32)
    frame_indices = frame_indices[frame_indices < video_length]

    new_driving_frames = []
    for idx in frame_indices:
        new_driving_frames.append(driving_frames[idx])
        if len(new_driving_frames) >= max_frame_num - 1:
            break

    video_length_add = max_frame_num - len(new_driving_frames) - 1
    new_driving_frames = (
        [new_driving_frames[0]] * 2
        + new_driving_frames[1:len(new_driving_frames) - 1]
        + [new_driving_frames[-1]] * video_length_add
    )
    return new_driving_frames
def write_mp4(video_path, samples, fps=12):
    clip = mp.ImageSequenceClip(samples, fps=fps)
    clip.write_videofile(video_path, audio_codec="aac", codec="libx264",
                         ffmpeg_params=["-crf", "18", "-preset", "slow"])


def save_video_with_audio(video_path, audio_path, save_path):
    video_clip = mp.VideoFileClip(video_path)
    audio_clip = mp.AudioFileClip(audio_path)
    if audio_clip.duration > video_clip.duration:
        audio_clip = audio_clip.subclip(0, video_clip.duration)
    video_with_audio = video_clip.set_audio(audio_clip)
    video_with_audio.write_videofile(save_path, fps=12, codec="libx264", audio_codec="aac")
    # Clean up
    video_clip.close()
    audio_clip.close()
    return save_path
# Global parameters
model_name = "pretrained_models/SkyReels-A1-5B/"
siglip_name = "pretrained_models/SkyReels-A1-5B/siglip-so400m-patch14-384"
weight_dtype = torch.bfloat16
max_frame_num = 49
sample_size = [480, 720]
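# 49 frames rendered at 12 fps correspond to roughly four seconds of output video
# at the 480x720 working resolution used throughout the pipeline.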
# Preload all models in global context
print("Loading models...")
# Load LMK extractor and processors
lmk_extractor = LMKExtractor()
processor = FaceAnimationProcessor(checkpoint='pretrained_models/smirk/SMIRK_em1.pt')
vis = FaceMeshVisualizer2d(forehead_edge=False, draw_head=False, draw_iris=False)
face_helper = FaceRestoreHelper(upscale_factor=1, face_size=512, crop_ratio=(1, 1),
                                det_model='retinaface_resnet50', save_ext='png', device="cuda")
# Load siglip visual encoder
siglip = SiglipVisionModel.from_pretrained(siglip_name)
siglip_normalize = SiglipImageProcessor.from_pretrained(siglip_name)
# Load diffposetalk
diffposetalk = DiffPoseTalk()
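# DiffPoseTalk turns the driving audio into per-frame facial motion coefficients;
# these are rendered into landmark maps by FaceAnimationProcessor further below.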
# Load SkyReels models
transformer = CogVideoXTransformer3DModel.from_pretrained(
    model_name,
    subfolder="transformer"
).to(weight_dtype)

vae = AutoencoderKLCogVideoX.from_pretrained(
    model_name,
    subfolder="vae"
).to(weight_dtype)

lmk_encoder = AutoencoderKLCogVideoX.from_pretrained(
    model_name,
    subfolder="pose_guider",
).to(weight_dtype)
# Set up pipeline
pipe = SkyReelsA1ImagePoseToVideoPipeline.from_pretrained(
    model_name,
    transformer=transformer,
    vae=vae,
    lmk_encoder=lmk_encoder,
    image_encoder=siglip,
    feature_extractor=siglip_normalize,
    torch_dtype=torch.bfloat16
)
pipe.to("cuda")
pipe.transformer = torch.compile(pipe.transformer)
pipe.vae = torch.compile(pipe.vae)
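# torch.compile is lazy: compilation happens on the first forward pass, so the
# first generation request is noticeably slower than subsequent ones.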
# pipe.enable_model_cpu_offload()
# pipe.vae.enable_tiling()
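# The two commented-out lines above are lower-VRAM alternatives (CPU offload and
# VAE tiling); they trade speed for memory and can be re-enabled if the GPU runs out of memory.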
print("Models loaded successfully!")
def process_image_audio(image_path, audio_path, guidance_scale=3.0, steps=10, progress=gr.Progress()):
    progress(0.1, desc="Processing inputs...")

    # Create a directory for outputs if it doesn't exist
    output_dir = "gradio_outputs"
    os.makedirs(output_dir, exist_ok=True)

    # Create temp files for processing
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video_file, \
         tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_output_file:
        temp_video_path = temp_video_file.name
        final_output_path = temp_output_file.name

    # Set seed
    seed = 43
    generator = torch.Generator(device="cuda").manual_seed(seed)
    progress(0.2, desc="Processing image...")

    # Load and process image
    image = load_image(image=image_path)
    image = processor.crop_and_resize(image, sample_size[0], sample_size[1])

    # Crop face
    ref_image, x1, y1 = processor.face_crop(np.array(image))
    face_h, face_w, _ = ref_image.shape
    source_image = ref_image

    progress(0.3, desc="Processing facial landmarks...")

    # Process source image
    source_outputs, source_tform, image_original = processor.process_source_image(source_image)

    progress(0.4, desc="Processing audio...")

    # Process audio and generate driving outputs
    driving_outputs = diffposetalk.infer_from_file(
        audio_path,
        source_outputs["shape_params"].view(-1)[:100].detach().cpu().numpy()
    )
    progress(0.5, desc="Processing landmarks from coefficients...")

    # Process landmarks
    out_frames = processor.preprocess_lmk3d_from_coef(
        source_outputs, source_tform, image_original.shape, driving_outputs
    )
    out_frames = parse_video(out_frames, max_frame_num)

    rescale_motions = np.zeros_like(image)[np.newaxis, :].repeat(48, axis=0)
    for ii in range(rescale_motions.shape[0]):
        rescale_motions[ii][y1:y1 + face_h, x1:x1 + face_w] = out_frames[ii]

    ref_image_resized = cv2.resize(ref_image, (512, 512))
    ref_lmk = lmk_extractor(ref_image_resized[:, :, ::-1])
    ref_img = vis.draw_landmarks_v3(
        (512, 512), (face_w, face_h),
        ref_lmk['lmks'].astype(np.float32), normed=True
    )

    first_motion = np.zeros_like(np.array(image))
    first_motion[y1:y1 + face_h, x1:x1 + face_w] = ref_img
    first_motion = first_motion[np.newaxis, :]

    motions = np.concatenate([first_motion, rescale_motions])
    input_video = motions[:max_frame_num]

    # Face alignment
    face_helper.clean_all()
    face_helper.read_image(np.array(image)[:, :, ::-1])
    face_helper.get_face_landmarks_5(only_center_face=True)
    face_helper.align_warp_face()
    align_face = face_helper.cropped_faces[0]
    image_face = align_face[:, :, ::-1]

    # Prepare input video
    input_video = torch.from_numpy(np.array(input_video)).permute([3, 0, 1, 2]).unsqueeze(0)
    input_video = input_video / 255
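    # input_video is now a float tensor shaped (1, C, F, H, W) with values in [0, 1],
    # passed as the control_video signal to the pipeline call below.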
    progress(0.6, desc="Generating animation (this may take a while)...")

    # Generate video
    # with torch.no_grad():
    sample = pipe(
        image=image,
        image_face=image_face,
        control_video=input_video,
        prompt="",
        negative_prompt="",
        height=480,
        width=720,
        num_frames=49,
        generator=generator,
        guidance_scale=guidance_scale,
        num_inference_steps=steps,
    )
    out_samples = sample.frames[0]
    out_samples = out_samples[2:]  # Skip first two frames

    progress(0.8, desc="Creating output video...")

    # Export video
    export_to_video(out_samples, temp_video_path, fps=12)

    progress(0.9, desc="Adding audio to video...")

    # Add audio to video
    result_path = save_video_with_audio(temp_video_path, audio_path, final_output_path)
    # Create side-by-side comparison
    target_h, target_w = sample_size[0], sample_size[1]
    final_images = []
    for i in range(len(out_samples)):
        frame1 = image
        frame2 = Image.fromarray(np.array(out_samples[i])).convert("RGB")
        result = Image.new('RGB', (target_w * 2, target_h))
        result.paste(frame1, (0, 0))
        result.paste(frame2, (target_w, 0))
        final_images.append(np.array(result))

    comparison_path = os.path.join(output_dir, "comparison.mp4")
    write_mp4(comparison_path, final_images, fps=12)

    # Add audio to comparison video
    comparison_with_audio = os.path.join(output_dir, "comparison_with_audio.mp4")
    comparison_with_audio = save_video_with_audio(comparison_path, audio_path, comparison_with_audio)

    progress(1.0, desc="Done!")
    return result_path, comparison_with_audio
# Create Gradio interface
with gr.Blocks(title="SkyReels A1 Face Animation") as app:
    gr.Markdown("# SkyReels A1 Face Animation")
    gr.Markdown("Upload a portrait image and an audio file to animate the face")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="filepath", label="Portrait Image")
            audio_input = gr.Audio(type="filepath", label="Driving Audio")
            with gr.Row():
                guidance_scale = gr.Slider(minimum=1.0, maximum=7.0, value=3.0, step=0.1, label="Guidance Scale")
                inference_steps = gr.Slider(minimum=5, maximum=30, value=10, step=1, label="Inference Steps")
            generate_button = gr.Button("Generate Animation", variant="primary")

        with gr.Column():
            output_video = gr.Video(label="Animation Result")
            comparison_video = gr.Video(label="Side-by-Side Comparison")

    generate_button.click(
        fn=process_image_audio,
        inputs=[image_input, audio_input, guidance_scale, inference_steps],
        outputs=[output_video, comparison_video]
    )

    gr.Markdown("""
## Instructions
1. Upload a portrait image (a frontal face works best)
2. Upload an audio file (WAV format recommended)
3. Adjust parameters if needed
4. Click "Generate Animation" to create the video

Note: Processing may take several minutes depending on your hardware.
""")
if __name__ == "__main__":
    app.launch(share=True)