# app.py
"""Gradio front end for SkyReels-V1 video generation.

The module runs a setup script, makes the ``SkyReels-V1`` checkout importable,
lazily builds a single shared predictor and exposes a small Gradio UI that
turns a prompt (plus an image for the I2V task) into an ``.mp4`` file.
"""
import spaces
import gradio as gr
import argparse
import sys
import os
import random
import subprocess
import logging

from PIL import Image
import numpy as np

logger = logging.getLogger(__name__)

# The setup script has to run before ``skyreelsinfer`` is imported below.
# Its exit status used to be ignored; a failure is now reported but is kept
# non-fatal so a re-run on an already prepared machine still starts.
_setup_result = subprocess.run(['sh', './sky.sh'])
if _setup_result.returncode != 0:
    logger.warning(
        "Setup script ./sky.sh exited with status %s", _setup_result.returncode
    )
sys.path.append("./SkyReels-V1")

from skyreelsinfer import TaskType  # noqa: E402
from skyreelsinfer.offload import OffloadConfig  # noqa: E402
from skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer  # noqa: E402
from diffusers.utils import export_to_video  # noqa: E402
import torch  # noqa: E402

# Disable every reduced-precision shortcut and request full float32 matmuls.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
torch.set_float32_matmul_precision("highest")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Exclusive upper bound used when a random seed is drawn.
_MAX_RANDOM_SEED = 4294967294

_NEGATIVE_PROMPT = (
    "Aerial view, aerial view, overexposed, low quality, deformation, "
    "a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
)

# Lazily created by init_predictor() and shared by all requests.
_predictor = None
# Overridden from the command line when the file is run as a script.
task_type = TaskType.I2V


@spaces.GPU(duration=90)
def init_predictor():
    """Create and initialize the shared predictor if it does not exist yet.

    The model id is chosen from the module-level ``task_type``. When a
    predictor already exists only a warning is logged.

    Raises:
        ValueError: If ``task_type`` is neither ``TaskType.I2V`` nor
            ``TaskType.T2V``.
    """
    global _predictor

    if _predictor is not None:
        logger.warning("Predictor already initialized (should be rare).")
        return

    if task_type == TaskType.I2V:
        model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
    elif task_type == TaskType.T2V:
        # NOTE(review): this looks like a placeholder id - confirm before
        # enabling the T2V task.
        model_id = "your_t2v_model_id"
    else:
        raise ValueError(f"Invalid task_type: {task_type}")

    _predictor = SkyReelsVideoSingleGpuInfer(
        task_type=task_type,
        model_id=model_id,
        quant_model=True,
        is_offload=True,
        offload_config=OffloadConfig(
            high_cpu_memory=True,
            parameters_level=True,
            compiler_transformer=False,
        ),
    )
    _predictor.initialize()
    logger.info("Predictor initialized")


def _resolve_seed(seed):
    """Return an integer seed, drawing a random one for ``-1`` or ``None``."""
    if seed is None or seed == -1:
        # Re-seed from the OS first, presumably so the draw stays random even
        # if something else seeded the global generator - TODO confirm.
        random.seed()
        return int(random.randrange(_MAX_RANDOM_SEED))
    # gr.Number delivers floats; the file name and the predictor get an int.
    return int(seed)


@spaces.GPU(duration=90)
def generate_video(prompt, seed, image=None):
    """Generate a video for ``prompt`` and write it below ``./result``.

    Args:
        prompt: Text prompt passed to the predictor; its first 100 characters
            (with ``/`` removed) become part of the output file name.
        seed: Seed value. ``-1`` (or an empty field, i.e. ``None``) requests a
            random seed; any other value is converted to ``int``.
        image: Path of the input image. Required for the I2V task.

    Returns:
        A ``(video_path, kwargs)`` tuple: the path of the written ``.mp4`` and
        the keyword arguments that were passed to the predictor.

    Raises:
        gr.Error: If the task is I2V and no image was supplied.
        ValueError: If the module-level ``task_type`` is not a known task.
    """
    seed = _resolve_seed(seed)

    kwargs = {
        "prompt": prompt,
        "height": 512,
        "width": 512,
        "num_frames": 97,
        "num_inference_steps": 30,
        "seed": seed,
        "guidance_scale": 6.0,
        "embedded_guidance_scale": 1.0,
        "negative_prompt": _NEGATIVE_PROMPT,
        "cfg_for": False,
    }

    if task_type == TaskType.I2V:
        # An explicit raise instead of ``assert``: asserts vanish under -O.
        if image is None:
            raise gr.Error("Please input an image for I2V task.")
        kwargs["image"] = Image.open(image)
    elif task_type == TaskType.T2V:
        pass
    else:
        raise ValueError(f"Invalid task_type: {task_type}")

    if _predictor is None:
        init_predictor()

    output = _predictor.infer(**kwargs)
    output = (output.cpu().numpy() * 255).astype(np.uint8)
    # Moves axis 1 to the end of a 5-D array; presumably channels-first to
    # channels-last - TODO confirm against the predictor's output layout.
    output = output.transpose(0, 2, 3, 4, 1)

    save_dir = f"./result/{task_type.name}"
    os.makedirs(save_dir, exist_ok=True)
    video_out_file = f"{save_dir}/{prompt[:100].replace('/','')}_{seed}.mp4"
    print(f"generate video, local path: {video_out_file}")
    # NOTE(review): the whole 5-D uint8 array is handed to export_to_video;
    # confirm it accepts that rather than a list of per-frame arrays.
    export_to_video(output, video_out_file, fps=24)
    return video_out_file, kwargs


def create_gradio_interface():
    """Build and return the Gradio ``Blocks`` UI wired to ``generate_video``."""
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                image = gr.Image(label="Upload Image", type="filepath")
                prompt = gr.Textbox(label="Input Prompt")
                seed = gr.Number(label="Random Seed", value=-1)
            with gr.Column():
                submit_button = gr.Button("Generate Video")
                output_video = gr.Video(label="Generated Video")
                output_params = gr.Textbox(label="Output Parameters")

        submit_button.click(
            fn=generate_video,
            inputs=[prompt, seed, image],
            outputs=[output_video, output_params],
        )
    return demo


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--task_type",
        type=str,
        default="i2v",
        choices=["t2v", "i2v"],
        help="Task type, 't2v' for text-to-video, 'i2v' for image-to-video.",
    )
    args = parser.parse_args()

    # ``choices`` above guarantees that one of the two branches matches.
    if args.task_type == "t2v":
        task_type = TaskType.T2V
    elif args.task_type == "i2v":
        task_type = TaskType.I2V

    demo = create_gradio_interface()
    demo.queue().launch()