import spaces
import gradio as gr
import argparse
import sys
import time
import os
import random
#sys.path.append("..")
from skyreelsinfer import TaskType
from skyreelsinfer.offload import OffloadConfig
from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer
from diffusers.utils import export_to_video
from diffusers.utils import load_image
import torch
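# Force full-precision matmul: disable TF32 and reduced-precision (bf16/fp16)
# reduction paths on both CUDA matmul and cuDNN, and turn off cuDNN autotuning.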
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
torch.set_float32_matmul_precision("highest")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
predictor = None
task_type = None
def get_transformer_model_id(task_type: str) -> str:
    """Map the task type to the corresponding SkyReels transformer checkpoint."""
    return "Skywork/SkyReels-V1-Hunyuan-I2V" if task_type == "i2v" else "Skywork/SkyReels-V1-Hunyuan-T2V"
def init_predictor(task_type: str, gpu_num: int = 1):
    """Build the global SkyReelsVideoInfer predictor for the given task type."""
    global predictor
    predictor = SkyReelsVideoInfer(
        task_type=TaskType.I2V if task_type == "i2v" else TaskType.T2V,
        model_id=get_transformer_model_id(task_type),
        quant_model=True,
        world_size=gpu_num,
        is_offload=True,
        offload_config=OffloadConfig(
            high_cpu_memory=True,
            parameters_level=True,
            compiler_transformer=False,
        ),
    )
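# Note (assumption, not part of the original file): on a ZeroGPU Space, the
# GPU-bound entry point is typically wrapped with the `spaces` decorator so a
# GPU is attached for the duration of the call, e.g.:
#
#   @spaces.GPU(duration=120)
#   def generate_video(prompt, seed, image=None):
#       ...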
def generate_video(prompt, seed, image=None):
    global task_type
    print(f"image: {type(image)}")
    # A seed of -1 means "pick a random seed for this run".
    if seed == -1:
        random.seed(time.time())
        seed = int(random.randrange(4294967294))
    kwargs = {
        "prompt": prompt,
        "height": 512,
        "width": 512,
        "num_frames": 97,
        "num_inference_steps": 30,
        "seed": seed,
        "guidance_scale": 6.0,
        "embedded_guidance_scale": 1.0,
        "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
        "cfg_for": False,
    }
    if task_type == "i2v":
        assert image is not None, "Please provide an input image for image-to-video generation."
        kwargs["image"] = load_image(image=image)
    global predictor
    output = predictor.inference(kwargs)
    save_dir = f"./result/{task_type}"
    os.makedirs(save_dir, exist_ok=True)
    video_out_file = f"{save_dir}/{prompt[:100].replace('/', '')}_{seed}.mp4"
    print(f"generate video, local path: {video_out_file}")
    export_to_video(output, video_out_file, fps=24)
    return video_out_file, kwargs
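# Illustrative usage (not in the original file): generate_video can also be
# called directly once init_predictor(...) has run; the prompt and image path
# below are hypothetical examples.
#
#   video_path, params = generate_video(
#       prompt="a corgi running along a sunny beach",
#       seed=42,
#       image="examples/corgi.png",
#   )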
def create_gradio_interface(task_type):
    """Create a Gradio interface based on the task type."""
    if task_type == "i2v":
        with gr.Blocks() as demo:
            with gr.Row():
                image = gr.Image(label="Upload Image", type="filepath")
                prompt = gr.Textbox(label="Input Prompt")
                seed = gr.Number(label="Random Seed", value=-1)
            submit_button = gr.Button("Generate Video")
            output_video = gr.Video(label="Generated Video")
            output_params = gr.Textbox(label="Output Parameters")
            # Submit button logic
            submit_button.click(
                fn=generate_video,
                inputs=[prompt, seed, image],
                outputs=[output_video, output_params],
            )
    elif task_type == "t2v":
        with gr.Blocks() as demo:
            with gr.Row():
                prompt = gr.Textbox(label="Input Prompt")
                seed = gr.Number(label="Random Seed", value=-1)
            submit_button = gr.Button("Generate Video")
            output_video = gr.Video(label="Generated Video")
            output_params = gr.Textbox(label="Output Parameters")
            # Submit button logic
            submit_button.click(
                fn=generate_video,
                inputs=[prompt, seed],  # Text-to-video: no image input is passed.
                outputs=[output_video, output_params],
            )
    return demo
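# Illustrative sketch (assumption, not in the original file): `argparse` is
# imported above but unused; the task type and GPU count could instead be read
# from the command line like this:
#
#   parser = argparse.ArgumentParser(description="SkyReels video generation demo")
#   parser.add_argument("--task_type", choices=["i2v", "t2v"], default="i2v")
#   parser.add_argument("--gpu_num", type=int, default=1)
#   args = parser.parse_args()
#   init_predictor(task_type=args.task_type, gpu_num=args.gpu_num)
#   demo = create_gradio_interface(args.task_type)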
if __name__ == "__main__":
    # Hardcoded to image-to-video ("i2v"); also set the module-level task_type
    # read by generate_video so the uploaded image is actually used.
    task_type = "i2v"
    init_predictor(task_type=task_type, gpu_num=1)
    demo = create_gradio_interface(task_type)
    demo.launch()