import os
import time
import torch
import gradio as gr
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
from dfloat11 import DFloat11Model
import spaces
import uuid

# Set environment variables
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"  # cap allocator split size to reduce CUDA memory fragmentation
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # faster Hub downloads (needs the hf_transfer package installed)
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"  # silence non-actionable transformers warnings

# On ZeroGPU Spaces, @spaces.GPU requests a GPU for the duration of each call (a no-op elsewhere)
@spaces.GPU(enable_queue=True)
def generate_video(prompt, negative_prompt, width, height, num_frames,
                   guidance_scale, guidance_scale_2, num_inference_steps, fps):
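    # The pipeline is (re)loaded on every request; after the first download the
    # weights come from the local Hugging Face cache, so only load time is repeated.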
    torch.cuda.empty_cache()
    start_time = time.time()

    # Load the Wan2.2 pipeline; the VAE stays in float32 while the transformers run in bfloat16
    vae = AutoencoderKLWan.from_pretrained(
        "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
        subfolder="vae",
        torch_dtype=torch.float32
    )

    pipe = WanPipeline.from_pretrained(
        "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
        vae=vae,
        torch_dtype=torch.bfloat16
    )

    # Wrap both Wan2.2 transformers (the high- and low-noise denoising experts) with
    # DFloat11 losslessly compressed weights; cpu_offload keeps the compressed weights
    # in CPU memory to lower peak GPU usage
    DFloat11Model.from_pretrained(
        "DFloat11/Wan2.2-T2V-A14B-DF11",
        device="cpu",
        cpu_offload=True,
        bfloat16_model=pipe.transformer,
    )
    DFloat11Model.from_pretrained(
        "DFloat11/Wan2.2-T2V-A14B-2-DF11",
        device="cpu",
        cpu_offload=True,
        bfloat16_model=pipe.transformer_2,
    )

    # Keep idle pipeline components (text encoder, VAE, transformers) on the CPU,
    # moving each to the GPU only while it is in use
    pipe.enable_model_cpu_offload()

    # Run inference; guidance_scale applies to the high-noise expert (transformer) and
    # guidance_scale_2 to the low-noise expert (transformer_2)
    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_frames=num_frames,
        guidance_scale=guidance_scale,
        guidance_scale_2=guidance_scale_2,
        num_inference_steps=num_inference_steps,
    ).frames[0]

    # Write the frames to a uniquely named temp file so concurrent requests don't collide
    output_path = f"/tmp/video_{uuid.uuid4().hex}.mp4"
    export_to_video(result, output_path, fps=fps)

    elapsed = time.time() - start_time
    print(f"Video generated in {elapsed:.2f} seconds")

    return output_path


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## 🎥 Wan2.2 Text-to-Video Generator (ZeroGPU Ready)")

    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            value="A serene koi pond at night, with glowing lanterns reflecting on the rippling water. Ethereal fireflies dance above as cherry blossoms gently fall.",
            lines=3
        )
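        # The default negative prompt below is the Chinese one commonly used in Wan examples.
        # Rough translation: gaudy colors, overexposed, static, blurry details, subtitles, style,
        # artwork, painting, still image, overall gray, worst quality, low quality, JPEG artifacts,
        # ugly, mutilated, extra fingers, poorly drawn hands, poorly drawn face, deformed,
        # disfigured, malformed limbs, fused fingers, motionless frame, cluttered background,
        # three legs, crowded background, walking backwards.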
        negative_prompt = gr.Textbox(
            label="Negative Prompt",
            value="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
            lines=3
        )

    with gr.Row():
        width = gr.Slider(256, 1280, value=768, step=64, label="Width")
        height = gr.Slider(256, 720, value=432, step=16, label="Height")  # step of 16 keeps the 432 default on the slider grid (256 + 64k would skip it)
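        # Frame counts of the form 4k+1 (e.g. 81) typically align best with the Wan VAE's 4x temporal compression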
        num_frames = gr.Slider(8, 81, value=40, step=1, label="Number of Frames")
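        # FPS is only passed to export_to_video, so it changes playback speed, not how many frames are generated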
        fps = gr.Slider(8, 30, value=16, step=1, label="FPS")

    with gr.Row():
        guidance_scale = gr.Slider(1.0, 10.0, value=4.0, step=0.1, label="Guidance Scale")
        guidance_scale_2 = gr.Slider(1.0, 10.0, value=3.0, step=0.1, label="Guidance Scale 2")
        num_inference_steps = gr.Slider(10, 60, value=40, step=1, label="Inference Steps")

    with gr.Row():
        btn = gr.Button("🎬 Generate Video")
        output_video = gr.Video(label="Generated Video")

    btn.click(
        generate_video,
        inputs=[prompt, negative_prompt, width, height, num_frames, guidance_scale, guidance_scale_2, num_inference_steps, fps],
        outputs=[output_video]
    )

# Launch Gradio app
demo.launch()