SkyReels / app.py
1inkusFace's picture
Update app.py
6746920 verified
raw
history blame
2.77 kB
import spaces
import gradio as gr
import argparse
import sys
import time
import os
import random
from skyreelsinfer.offload import OffloadConfig
from skyreelsinfer import TaskType
from skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer
from diffusers.utils import export_to_video
from diffusers.utils import load_image
from PIL import Image
# ``predictor`` is created as a module-level global by init_predictor().
#
# Assign through os.environ rather than os.putenv(): os.putenv() does not
# update os.environ, whereas assigning to os.environ updates both the Python
# mapping and (via an implicit putenv) the process environment seen by
# native libraries. This runs at import time, before init_predictor().
os.environ["TOKENIZERS_PARALLELISM"] = "False"
#@spaces.GPU(duration=120)
def init_predictor():
    """Construct the SkyReels I2V inference object and store it in ``predictor``.

    Uses the ``Skywork/SkyReels-V1-Hunyuan-I2V`` model id with quantization
    off and ``is_offload=False``; an ``OffloadConfig`` is still passed along.
    The result is assigned to the module-level global ``predictor`` that
    generate_video() reads.
    """
    global predictor
    offload_settings = OffloadConfig(
        high_cpu_memory=True,
        parameters_level=True,
        compiler_transformer=False,
    )
    predictor = SkyReelsVideoSingleGpuInfer(
        task_type=TaskType.I2V,
        model_id="Skywork/SkyReels-V1-Hunyuan-I2V",
        quant_model=False,
        is_offload=False,
        offload_config=offload_settings,
    )
def _video_filename(prompt, seed):
    """Build a filesystem-safe ``<prompt-prefix>_<seed>.mp4`` file name.

    Only the first 100 characters of *prompt* are used, and every character
    outside ``[A-Za-z0-9._ -]`` is replaced with ``_``. Keeping the name ASCII
    bounds it to about 100 bytes plus the seed, well under the common
    255-byte file-name limit. A raw ``prompt[:100]`` slice can exceed that
    limit for multi-byte (e.g. CJK) prompts, and it also lets through
    newlines and backslashes.
    """
    import re

    stem = re.sub(r"[^A-Za-z0-9._ -]", "_", (prompt or "")[:100])
    return f"{stem}_{seed}.mp4"


@spaces.GPU(duration=80)
def generate_video(prompt, seed, image=None):
    """Generate a short video from *image*, guided by *prompt*.

    Args:
        prompt: Text prompt forwarded to the predictor.
        seed: Sampling seed. ``-1`` (or ``None`` from a cleared field) picks
            a random seed. Gradio ``Number`` values may arrive as floats, so
            the seed is coerced to ``int`` before use.
        image: Anything ``diffusers.utils.load_image`` accepts (the UI passes
            a file path). Required.

    Returns:
        Path of the exported ``.mp4`` file under ``./result``.

    Raises:
        gr.Error: If no image was provided. Gradio shows the message to the user.
    """
    print(f"image:{type(image)}")
    # Use gr.Error instead of assert: asserts are stripped under ``python -O``,
    # and gr.Error is shown in the UI rather than as a raw traceback.
    if image is None:
        raise gr.Error("please input image")

    # The global PRNG is already seeded at interpreter start-up, so it is not
    # reseeded from the wall clock here.
    if seed is None or seed == -1:
        seed = random.randrange(4294967294)
    seed = int(seed)

    # A single source for the output size keeps the image resize in sync with
    # the height/width passed to the pipeline. PIL's resize takes (w, h).
    height, width = 320, 320
    kwargs = {
        "prompt": prompt,
        "height": height,
        "width": width,
        "num_frames": 64,
        "num_inference_steps": 10,
        "seed": seed,
        "guidance_scale": 6.0,
        "embedded_guidance_scale": 1.0,
        "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
        "cfg_for": False,
        "image": load_image(image=image).resize((width, height), Image.LANCZOS),
    }
    # ``predictor`` is the module-level global assigned by init_predictor().
    output = predictor.inference(kwargs)

    save_dir = "./result"
    os.makedirs(save_dir, exist_ok=True)
    video_out_file = os.path.join(save_dir, _video_filename(prompt, seed))
    print(f"generate video, local path: {video_out_file}")
    export_to_video(output, video_out_file, fps=24)
    return video_out_file
def create_gradio_interface():
    """Build the Gradio Blocks UI wired to ``generate_video``.

    Returns:
        The ``gr.Blocks`` app. The caller is responsible for calling ``launch()``.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            image = gr.Image(label="Upload Image", type="filepath")
            prompt = gr.Textbox(label="Input Prompt")
            # precision=0 makes Gradio round the value and return an int, so
            # the seed reaches generate_video as an integer, not a float.
            seed = gr.Number(label="Random Seed", value=-1, precision=0)
            submit_button = gr.Button("Generate Video")
            output_video = gr.Video(label="Generated Video")
        submit_button.click(
            fn=generate_video,
            # Order must match generate_video(prompt, seed, image).
            inputs=[prompt, seed, image],
            outputs=[output_video],
        )
    return demo
# Script entry point: construct the predictor once at start-up, then build
# the UI and serve it. ``demo`` is kept as a module-level name.
if __name__ == "__main__":
    init_predictor()
    demo = create_gradio_interface()
    demo.launch()