import os
import random
import uuid

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image, export_to_video
from moviepy import ImageSequenceClip, AudioFileClip, VideoFileClip
from PIL import Image, ImageOps

# ============================================================
# 🔥 GLOBAL PERFORMANCE SETTINGS (H200 OPTIMIZED)
# ============================================================
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_grad_enabled(False)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)

DEVICE = "cuda"
DTYPE = torch.bfloat16

# ============================================================
# 🎯 DISTILLED SIGMAS
# ============================================================
# Fixed 8-step noise schedule for the distilled checkpoint
# (matches num_inference_steps=8 in the pipeline calls below).
DISTILLED_SIGMA_VALUES = [
    1.0, 0.99375, 0.9875, 0.98125,
    0.975, 0.909375, 0.725, 0.421875,
]

# ============================================================
# 🚀 LOAD MODEL ON STARTUP (ONLY ONCE)
# ============================================================
print("🚀 Loading LTX-2 Distilled on H200...")

pipe = DiffusionPipeline.from_pretrained(
    "rootonchair/LTX-2-19b-distilled",
    custom_pipeline="multimodalart/ltx2-audio-to-video",
    torch_dtype=DTYPE,
)
pipe.to(DEVICE)

# Enable memory efficient attention (optional speed-up; safe to skip
# when xFormers is not installed).
try:
    pipe.enable_xformers_memory_efficient_attention()
    print("✅ xFormers enabled")
except Exception:
    print("⚠️ xFormers not available")

# Load & fuse the detailer LoRA once at startup, then unload the adapter
# weights so every request runs on the already-fused model.
print("📦 Loading Detailer LoRA...")
pipe.load_lora_weights(
    "Lightricks/LTX-2-19b-IC-LoRA-Detailer",
    adapter_name="detailer",
)
pipe.fuse_lora(lora_scale=0.8)
pipe.unload_lora_weights()

print("🔥 Model fully loaded on CUDA.")


# ============================================================
# 🎬 HELPER FUNCTIONS
# ============================================================
def save_video(video_frames, audio_path=None, fps=24):
    """Encode ``video_frames`` to an mp4, optionally muxing in ``audio_path``.

    ``video_frames`` may be a (possibly batch-nested) list of PIL images,
    a path to an existing video file, or any object that
    ``diffusers.utils.export_to_video`` accepts.

    Returns:
        str: path of the written mp4 file.
    """
    # uuid4 avoids the filename collisions a small random int can cause
    # when several Gradio requests finish at the same time.
    output_filename = f"output_{uuid.uuid4().hex}.mp4"
    temp_path = None

    if isinstance(video_frames, list):
        # Pipelines may return [[frames]] (batch of one) or a flat [frames].
        if video_frames and isinstance(video_frames[0], list):
            frames = video_frames[0]
        else:
            frames = video_frames
        np_frames = [np.array(img) for img in frames]
        clip = ImageSequenceClip(np_frames, fps=fps)
    elif isinstance(video_frames, str):
        clip = VideoFileClip(video_frames)
    else:
        # Fallback: let diffusers serialize, then re-open with moviepy.
        # Unique name prevents concurrent calls clobbering each other.
        temp_path = f"temp_video_no_audio_{uuid.uuid4().hex}.mp4"
        export_to_video(video_frames, temp_path, fps=fps)
        clip = VideoFileClip(temp_path)

    audio_clip = None
    try:
        if audio_path:
            audio_clip = AudioFileClip(audio_path)
            # Trim the audio so it never outlasts the video track.
            if audio_clip.duration > clip.duration:
                audio_clip = audio_clip.subclipped(0, clip.duration)
            clip = clip.with_audio(audio_clip)
            audio_codec = "aac"
        else:
            audio_codec = None

        clip.write_videofile(
            output_filename,
            fps=fps,
            codec="libx264",
            audio_codec=audio_codec,
            logger=None,
        )
    finally:
        # Release decoder/encoder resources even if encoding failed,
        # and remove the intermediate file from the export fallback.
        clip.close()
        if audio_clip is not None:
            audio_clip.close()
        if temp_path and os.path.exists(temp_path):
            os.remove(temp_path)

    return output_filename


def infer_aspect_ratio(image):
    """Return the supported (width, height) closest to ``image``'s ratio."""
    resolutions = {
        "1:1": (512, 512),
        "16:9": (768, 512),
        "9:16": (512, 768),
    }
    aspect_ratios = {
        "1:1": 1.0,
        "16:9": 16 / 9,
        "9:16": 9 / 16,
    }

    width, height = image.size
    image_ratio = width / height
    closest_ratio = min(
        aspect_ratios,
        key=lambda k: abs(aspect_ratios[k] - image_ratio),
    )
    return resolutions[closest_ratio]


def process_image_for_aspect_ratio(image):
    """Center-crop and resize ``image`` to the nearest supported resolution.

    Returns:
        tuple: (processed PIL image, target width, target height).
    """
    target_w, target_h = infer_aspect_ratio(image)
    processed_img = ImageOps.fit(
        image,
        (target_w, target_h),
        method=Image.LANCZOS,
        centering=(0.5, 0.5),
    )
    return processed_img, target_w, target_h


def get_audio_duration(audio_path):
    """Return a slider update matching the audio length (capped at 12 s)."""
    if audio_path is None:
        return gr.update()
    try:
        audio_clip = AudioFileClip(audio_path)
        duration = audio_clip.duration
        audio_clip.close()
        capped = min(duration, 12.0)
        # Snap to the duration slider's 0.5 s step.
        rounded = round(capped * 2) / 2
        return gr.update(value=rounded)
    except Exception:
        # Unreadable/invalid audio: leave the slider unchanged.
        # (Narrowed from a bare `except:` which also swallowed
        # KeyboardInterrupt/SystemExit.)
        return gr.update()


# ============================================================
# 🎥 GENERATION FUNCTION (GPU ONLY HERE)
# ============================================================
@spaces.GPU(duration=85, size="xlarge")
def generate(
    image_path,
    audio_path,
    prompt,
    negative_prompt,
    video_duration,
    seed,
    progress=gr.Progress(track_tqdm=True),
):
    """Run the distilled LTX-2 image(-and-audio)-to-video pipeline.

    Args:
        image_path: path of the conditioning image (required).
        audio_path: optional audio file; when given, its duration
            (capped at 12 s) overrides ``video_duration``.
        prompt / negative_prompt: text conditioning.
        video_duration: requested clip length in seconds.
        seed: RNG seed; -1 picks a random one.

    Returns:
        tuple: (output mp4 path, seed actually used).

    Raises:
        gr.Error: if no image was provided.
    """
    if not image_path:
        raise gr.Error("Please provide an image.")

    if seed == -1:
        seed = random.randint(0, 1_000_000)
    # Use the module-level DEVICE constant instead of re-hard-coding "cuda".
    generator = torch.Generator(device=DEVICE).manual_seed(seed)

    original_image = load_image(image_path)
    image, width, height = process_image_for_aspect_ratio(original_image)

    fps = 24.0

    # If audio exists → its length overrides the requested duration.
    if audio_path:
        audio_clip = AudioFileClip(audio_path)
        video_duration = min(audio_clip.duration, 12.0)
        audio_clip.close()

    # LTX expects frame counts of the form 8k + 1, with a minimum of 9.
    total_frames = int(video_duration * fps)
    base_block = round(total_frames / 8) * 8
    num_frames = max(base_block + 1, 9)

    print(f"Seed: {seed} | {width}x{height} | Frames: {num_frames}")

    with torch.inference_mode():
        if audio_path:
            # Audio-conditioned call returns (video, audio); audio is
            # muxed from the original file in save_video, so discard it.
            video_output, _ = pipe(
                image=image,
                audio=audio_path,
                prompt=prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                num_frames=num_frames,
                frame_rate=fps,
                num_inference_steps=8,
                sigmas=DISTILLED_SIGMA_VALUES,
                guidance_scale=1.0,
                generator=generator,
                return_dict=False,
            )
        else:
            video_output = pipe(
                image=image,
                prompt=prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                num_frames=num_frames,
                frame_rate=fps,
                num_inference_steps=8,
                sigmas=DISTILLED_SIGMA_VALUES,
                guidance_scale=1.0,
                generator=generator,
                return_dict=False,
            )[0]

    output_path = save_video(video_output, audio_path, fps=fps)
    return output_path, seed


# ============================================================
# 🖥️ GRADIO UI
# ============================================================
css = "#col-container { max-width: 800px; margin: 0 auto; }"

with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# ⚡ LTX-2 Distilled Image-to-Video (Audio Optional)")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="filepath", height=300)
                input_audio = gr.Audio(type="filepath", label="Optional Audio")
            with gr.Column():
                result_video = gr.Video()

        prompt = gr.Textbox(
            value="A person speaking naturally",
            lines=2,
        )
        video_duration = gr.Slider(1.0, 12.0, step=0.5, value=4.0)

        with gr.Accordion("Advanced", open=False):
            negative_prompt = gr.Textbox(
                value="low quality, worst quality"
            )
            seed = gr.Number(value=-1, precision=0)

        run_btn = gr.Button("Generate", variant="primary")
        used_seed = gr.Number(visible=False)

    # Uploading audio pre-fills the duration slider with the clip length.
    input_audio.change(
        fn=get_audio_duration,
        inputs=[input_audio],
        outputs=[video_duration],
    )

    run_btn.click(
        fn=generate,
        inputs=[
            input_image,
            input_audio,
            prompt,
            negative_prompt,
            video_duration,
            seed,
        ],
        outputs=[result_video, used_seed],
    )

if __name__ == "__main__":
    demo.queue().launch()