# LTX2_distill / app_noaud.py
# Origin: rahul7star's Hugging Face Space ("Create app_noaud.py", commit 8f29086, verified)
import spaces
import gradio as gr
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image, export_to_video
import random
import numpy as np
from moviepy import ImageSequenceClip, AudioFileClip, VideoFileClip
from PIL import Image, ImageOps
import os
# ============================================================
# 🔥 GLOBAL PERFORMANCE SETTINGS (H200 OPTIMIZED)
# ============================================================
# Allow TF32 matmuls/convolutions for faster inference on Ampere+ GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Inference-only app: disable autograd globally.
torch.set_grad_enabled(False)
# Prefer flash / memory-efficient scaled-dot-product attention kernels.
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
DEVICE = "cuda"          # model always runs on GPU
DTYPE = torch.bfloat16   # bf16 weights/activations
# ============================================================
# 🎯 DISTILLED SIGMAS
# ============================================================
# Noise schedule for the distilled model: 8 sigma values, matching the
# num_inference_steps=8 used in generate() below.
DISTILLED_SIGMA_VALUES = [
    1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875
]
# ============================================================
# 🚀 LOAD MODEL ON STARTUP (ONLY ONCE)
# ============================================================
print("🚀 Loading LTX-2 Distilled on H200...")
pipe = DiffusionPipeline.from_pretrained(
    "rootonchair/LTX-2-19b-distilled",
    custom_pipeline="multimodalart/ltx2-audio-to-video",
    torch_dtype=DTYPE,
)
pipe.to(DEVICE)
# Enable memory efficient attention (best effort — xFormers may not be installed).
try:
    pipe.enable_xformers_memory_efficient_attention()
    print("✅ xFormers enabled")
except Exception:
    print("⚠️ xFormers not available")
# Load & Fuse LoRA ONCE: fusing bakes the LoRA weights into the base model
# (at scale 0.8), after which the adapter machinery can be unloaded.
print("📦 Loading Detailer LoRA...")
pipe.load_lora_weights(
    "Lightricks/LTX-2-19b-IC-LoRA-Detailer",
    adapter_name="detailer"
)
pipe.fuse_lora(lora_scale=0.8)
pipe.unload_lora_weights()
print("🔥 Model fully loaded on CUDA.")
# ============================================================
# 🎬 HELPER FUNCTIONS
# ============================================================
def save_video(video_frames, audio_path=None, fps=24):
    """Encode video frames to an mp4 file, optionally muxing in audio.

    Parameters
    ----------
    video_frames : list | str | object
        Either a list of PIL images (possibly nested one level deep, as a
        batched pipeline output), a path to an existing video file, or any
        frame container accepted by diffusers' export_to_video.
    audio_path : str | None
        Optional audio file; trimmed to the video length when longer.
    fps : int | float
        Output frame rate.

    Returns
    -------
    str
        Path of the written mp4 file.
    """
    output_filename = f"output_{random.randint(0, 100000)}.mp4"
    clip = None
    audio_clip = None
    temp_path = None
    try:
        if isinstance(video_frames, list):
            # Pipelines may return a batch: a list containing one list of frames.
            if video_frames and isinstance(video_frames[0], list):
                frames = video_frames[0]
            else:
                frames = video_frames
            np_frames = [np.array(img) for img in frames]
            clip = ImageSequenceClip(np_frames, fps=fps)
        elif isinstance(video_frames, str):
            clip = VideoFileClip(video_frames)
        else:
            # Fall back to diffusers' exporter, then re-open with moviepy.
            # Randomized temp name avoids collisions between concurrent requests.
            temp_path = f"temp_video_no_audio_{random.randint(0, 100000)}.mp4"
            export_to_video(video_frames, temp_path, fps=fps)
            clip = VideoFileClip(temp_path)
        audio_codec = None
        if audio_path:
            audio_clip = AudioFileClip(audio_path)
            if audio_clip.duration > clip.duration:
                audio_clip = audio_clip.subclipped(0, clip.duration)
            clip = clip.with_audio(audio_clip)
            audio_codec = "aac"  # mp4-compatible audio codec
        clip.write_videofile(
            output_filename,
            fps=fps,
            codec="libx264",
            audio_codec=audio_codec,
            logger=None
        )
    finally:
        # Release encoder/decoder resources and remove the intermediate file
        # even if encoding fails.
        if clip is not None:
            clip.close()
        if audio_clip is not None:
            audio_clip.close()
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)
    return output_filename
def infer_aspect_ratio(image):
    """Return the supported (width, height) whose aspect ratio best matches *image*.

    Only three output resolutions are supported; the one whose ratio is
    numerically closest to the image's width/height ratio is chosen.
    """
    # Map each supported aspect ratio directly to its output resolution.
    candidates = {
        1.0: (512, 512),      # 1:1
        16 / 9: (768, 512),   # 16:9
        9 / 16: (512, 768),   # 9:16
    }
    w, h = image.size
    target_ratio = w / h
    best_ratio = min(candidates, key=lambda r: abs(r - target_ratio))
    return candidates[best_ratio]
def process_image_for_aspect_ratio(image):
    """Center-crop and resize *image* to the nearest supported resolution.

    Returns a tuple (fitted_image, width, height).
    """
    width, height = infer_aspect_ratio(image)
    # ImageOps.fit crops to the target aspect ratio around the center,
    # then resamples with Lanczos to the exact target size.
    fitted = ImageOps.fit(
        image,
        (width, height),
        method=Image.LANCZOS,
        centering=(0.5, 0.5),
    )
    return fitted, width, height
def get_audio_duration(audio_path):
    """Return a gr.update() that syncs the duration slider to an audio file.

    The audio length is capped at 12 s and rounded to the nearest 0.5 s to
    match the slider's range and step. Returns a no-op update when no file
    is given or the file cannot be read.
    """
    if audio_path is None:
        return gr.update()
    try:
        audio_clip = AudioFileClip(audio_path)
        try:
            duration = audio_clip.duration
        finally:
            # Release the decoder even if reading the duration fails.
            audio_clip.close()
    except Exception:
        # Unreadable/corrupt audio: leave the slider unchanged.
        # (Bare `except:` replaced — it also caught KeyboardInterrupt/SystemExit.)
        return gr.update()
    capped = min(duration, 12.0)
    rounded = round(capped * 2) / 2
    return gr.update(value=rounded)
# ============================================================
# 🎥 GENERATION FUNCTION (GPU ONLY HERE)
# ============================================================
@spaces.GPU(duration=85, size="xlarge")
def generate(
    image_path,
    audio_path,
    prompt,
    negative_prompt,
    video_duration,
    seed,
    progress=gr.Progress(track_tqdm=True)
):
    """Run the LTX-2 distilled pipeline on one image (plus optional audio).

    Returns (output_video_path, seed_used). Raises gr.Error when no image
    is provided.
    """
    if not image_path:
        raise gr.Error("Please provide an image.")
    # seed == -1 means "pick a random seed"; the actual seed is returned to the UI.
    if seed == -1:
        seed = random.randint(0, 1_000_000)
    generator = torch.Generator(device="cuda").manual_seed(seed)
    original_image = load_image(image_path)
    # Crop/resize to the closest supported resolution (1:1, 16:9 or 9:16).
    image, width, height = process_image_for_aspect_ratio(original_image)
    fps = 24.0
    # If audio exists → override duration (capped at 12 s) so video matches audio
    if audio_path:
        audio_clip = AudioFileClip(audio_path)
        video_duration = min(audio_clip.duration, 12.0)
        audio_clip.close()
    total_frames = int(video_duration * fps)
    # Snap the frame count to a multiple of 8 plus 1, minimum 9.
    # NOTE(review): presumably the model requires 8k+1 frames — confirm.
    base_block = round(total_frames / 8) * 8
    num_frames = max(base_block + 1, 9)
    print(f"Seed: {seed} | {width}x{height} | Frames: {num_frames}")
    with torch.inference_mode():
        if audio_path:
            # Audio-driven branch: with return_dict=False the custom pipeline
            # returns a (video, audio) tuple; the audio half is discarded here
            # (the original file is muxed back in by save_video).
            video_output, _ = pipe(
                image=image,
                audio=audio_path,
                prompt=prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                num_frames=num_frames,
                frame_rate=fps,
                num_inference_steps=8,
                sigmas=DISTILLED_SIGMA_VALUES,
                guidance_scale=1.0,
                generator=generator,
                return_dict=False,
            )
        else:
            # Image-only branch: first tuple element holds the frames.
            video_output = pipe(
                image=image,
                prompt=prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                num_frames=num_frames,
                frame_rate=fps,
                num_inference_steps=8,
                sigmas=DISTILLED_SIGMA_VALUES,
                guidance_scale=1.0,
                generator=generator,
                return_dict=False,
            )[0]
    output_path = save_video(video_output, audio_path, fps=fps)
    return output_path, seed
# ============================================================
# 🖥️ GRADIO UI
# ============================================================
# ============================================================
# 🖥️ GRADIO UI
# ============================================================
css = "#col-container { max-width: 800px; margin: 0 auto; }"
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# ⚡ LTX-2 Distilled Image-to-Video (Audio Optional)")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="filepath", height=300)
                input_audio = gr.Audio(type="filepath", label="Optional Audio")
            with gr.Column():
                result_video = gr.Video()
        prompt = gr.Textbox(
            value="A person speaking naturally",
            lines=2
        )
        # Duration in seconds; auto-updated from the audio length when audio is given.
        video_duration = gr.Slider(1.0, 12.0, step=0.5, value=4.0)
        with gr.Accordion("Advanced", open=False):
            negative_prompt = gr.Textbox(
                value="low quality, worst quality"
            )
            seed = gr.Number(value=-1, precision=0)  # -1 → random seed
        run_btn = gr.Button("Generate", variant="primary")
        # Hidden output so generate() can report the seed actually used.
        used_seed = gr.Number(visible=False)
    # Sync the duration slider to an uploaded/changed audio file.
    input_audio.change(
        fn=get_audio_duration,
        inputs=[input_audio],
        outputs=[video_duration]
    )
    run_btn.click(
        fn=generate,
        inputs=[
            input_image,
            input_audio,
            prompt,
            negative_prompt,
            video_duration,
            seed
        ],
        outputs=[result_video, used_seed]
    )
if __name__ == "__main__":
    demo.queue().launch()