File size: 4,236 Bytes
227bc73
ecea5f9
227bc73
0e0805e
227bc73
ecea5f9
227bc73
 
ecea5f9
 
 
 
227bc73
f9a089d
06d7801
827103d
 
 
 
 
 
 
 
 
ad2ae6c
4d4355a
ecea5f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f1996d
ecea5f9
 
 
b113647
227bc73
ecea5f9
227bc73
ecea5f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227bc73
ecea5f9
227bc73
ada5e6d
0e0805e
c0db3ab
b113647
ecea5f9
 
 
 
 
4d4355a
 
 
ecea5f9
 
 
 
 
 
 
 
 
 
4d4355a
ecea5f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c0db3ab
ecea5f9
b113647
227bc73
ecea5f9
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import spaces

import gradio as gr
import argparse
import sys
import time
import os
import random
#sys.path.append("..")
from skyreelsinfer import TaskType
from skyreelsinfer.offload import OffloadConfig
from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer
from diffusers.utils import export_to_video
from diffusers.utils import load_image

import torch

# --- Numeric precision -------------------------------------------------------
# Request the "highest" float32 matmul precision and switch off TF32 as well as
# the fp16/bf16 reduced-precision reductions for CUDA matmuls and cuDNN.
torch.set_float32_matmul_precision("highest")
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False

# --- cuDNN behaviour flags ---------------------------------------------------
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False

# First CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Module-level state shared by the functions below:
#   predictor -- the inference pipeline, assigned by init_predictor()
#   task_type -- read by generate_video() to choose i2v vs. t2v handling
predictor = None
task_type = None

def get_transformer_model_id(task_type:str) -> str:
    """Return the model id string that matches ``task_type``.

    ``"i2v"`` maps to the I2V model id; every other value falls back to the
    T2V model id.
    """
    if task_type == "i2v":
        return "Skywork/SkyReels-V1-Hunyuan-I2V"
    return "Skywork/SkyReels-V1-Hunyuan-T2V"

def init_predictor(task_type:str, gpu_num:int=1):
    """Construct the inference pipeline and store it in the global ``predictor``.

    Args:
        task_type: ``"i2v"`` selects ``TaskType.I2V``; any other value selects
            ``TaskType.T2V``. The same string picks the model id via
            ``get_transformer_model_id``.
        gpu_num: Passed to ``SkyReelsVideoInfer`` as ``world_size``.
    """
    global predictor

    if task_type == "i2v":
        selected_task = TaskType.I2V
    else:
        selected_task = TaskType.T2V

    offload_settings = OffloadConfig(
        high_cpu_memory=True,
        parameters_level=True,
        compiler_transformer=False,
    )

    predictor = SkyReelsVideoInfer(
        task_type=selected_task,
        model_id=get_transformer_model_id(task_type),
        quant_model=True,
        world_size=gpu_num,
        is_offload=True,
        offload_config=offload_settings,
    )

def generate_video(prompt, seed, image=None):
    """Run one generation request and export the result as an MP4 file.

    Args:
        prompt: Text prompt forwarded to the predictor. Its first 100
            characters (with ``/`` removed) also become part of the output
            file name.
        seed: Seed for the run. ``-1`` or ``None`` (e.g. an emptied number
            field) selects a fresh random seed; any other value is converted
            with ``int()``.
        image: Optional input image, passed through
            ``diffusers.utils.load_image``. Required when the module-level
            ``task_type`` is ``"i2v"``.

    Returns:
        A ``(video_out_file, kwargs)`` tuple: the path of the written video
        and the dict of parameters that was handed to ``predictor.inference``.

    Raises:
        gr.Error: If ``task_type`` is ``"i2v"`` and no image was supplied.
    """
    global predictor, task_type
    print(f"image:{type(image)}")

    if seed is None or seed == -1:
        # SystemRandom draws from OS entropy, so the result does not depend on
        # (and does not reset) the global ``random`` state, which other code
        # may have seeded deterministically.
        seed = random.SystemRandom().randrange(4294967294)
    else:
        # A number widget can deliver a float; keep the seed and the file name
        # integral.
        seed = int(seed)

    kwargs = {
        "prompt": prompt,
        "height": 512,
        "width": 512,
        "num_frames": 97,
        "num_inference_steps": 30,
        "seed": seed,
        "guidance_scale": 6.0,
        "embedded_guidance_scale": 1.0,
        "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
        "cfg_for": False,
    }

    # Attach the image whenever one is supplied instead of gating on
    # ``task_type``: the module-level ``task_type`` may still be None, in which
    # case an uploaded image would otherwise be silently dropped.
    if image is not None:
        kwargs["image"] = load_image(image=image)
    elif task_type == "i2v":
        # Raise instead of ``assert`` so the check survives ``python -O`` and
        # the message is shown to the user by Gradio.
        raise gr.Error("please input image")

    output = predictor.inference(kwargs)

    save_dir = f"./result/{task_type}"
    os.makedirs(save_dir, exist_ok=True)
    # Strip "/" so the prompt cannot introduce extra path components.
    video_out_file = f"{save_dir}/{prompt[:100].replace('/','')}_{seed}.mp4"
    print(f"generate video, local path: {video_out_file}")
    export_to_video(output, video_out_file, fps=24)
    return video_out_file, kwargs


def create_gradio_interface(task_type):
    """Create a Gradio interface based on the task type.

    Args:
        task_type: ``"i2v"`` builds a form with an image upload, a prompt and
            a seed; ``"t2v"`` builds the same form without the image upload.

    Returns:
        The ``gr.Blocks`` app whose "Generate Video" button calls
        ``generate_video``.

    Raises:
        ValueError: If ``task_type`` is neither ``"i2v"`` nor ``"t2v"``.
    """
    # Fail early with a clear message; without this check an unknown task type
    # would fall through both branches and hit an unbound ``demo`` at return.
    if task_type not in ("i2v", "t2v"):
        raise ValueError(
            f"unsupported task_type {task_type!r}; expected 'i2v' or 't2v'"
        )

    needs_image = task_type == "i2v"

    with gr.Blocks() as demo:
        with gr.Row():
            # The image upload is created first so it keeps its position ahead
            # of the prompt and seed fields.
            image = (
                gr.Image(label="Upload Image", type="filepath")
                if needs_image
                else None
            )
            prompt = gr.Textbox(label="Input Prompt")
            seed = gr.Number(label="Random Seed", value=-1)
        submit_button = gr.Button("Generate Video")
        output_video = gr.Video(label="Generated Video")
        output_params = gr.Textbox(label="Output Parameters")

        # Input order must match generate_video(prompt, seed, image=None).
        inputs = [prompt, seed, image] if needs_image else [prompt, seed]

        # Submit button logic
        submit_button.click(
            fn=generate_video,
            inputs=inputs,
            outputs=[output_video, output_params],
        )

    return demo

if __name__ == "__main__":
    # generate_video() reads the module-level ``task_type`` to decide whether
    # an image is required and which ./result/<task_type> directory to write
    # to, so record the chosen mode here before any request is served.
    task_type = "i2v"
    init_predictor(task_type=task_type, gpu_num=1)
    demo = create_gradio_interface(task_type)
    demo.launch()