import spaces  # must be imported before torch on ZeroGPU Spaces
import gradio as gr
import torch
import torchaudio
import tempfile
import warnings
import os

warnings.filterwarnings("ignore")

from sam_audio import SAMAudio, SAMAudioProcessor

# Available models
MODELS = {
    "sam-audio-small": "facebook/sam-audio-small",
    "sam-audio-base": "facebook/sam-audio-base",
    "sam-audio-large": "facebook/sam-audio-large",
    "sam-audio-small-tv (Visual)": "facebook/sam-audio-small-tv",
    "sam-audio-base-tv (Visual)": "facebook/sam-audio-base-tv",
    "sam-audio-large-tv (Visual)": "facebook/sam-audio-large-tv",
}
DEFAULT_MODEL = "sam-audio-small"

EXAMPLES_DIR = "examples"
EXAMPLE_FILE = os.path.join(EXAMPLES_DIR, "office.mp4")

# Global model cache: keep at most one model resident, reloading only when
# the user selects a different checkpoint.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
current_model_name = None
model = None
processor = None


def load_model(model_name):
    """Load the requested checkpoint into the global cache (no-op if already cached)."""
    global current_model_name, model, processor
    if current_model_name == model_name and model is not None:
        return
    model_id = MODELS.get(model_name, MODELS[DEFAULT_MODEL])
    print(f"Loading {model_id}...")
    model = SAMAudio.from_pretrained(model_id).to(device).eval()
    processor = SAMAudioProcessor.from_pretrained(model_id)
    current_model_name = model_name
    print(f"Model {model_id} loaded on {device}.")


# Warm the cache at startup so the first request doesn't pay the load cost.
load_model(DEFAULT_MODEL)


def save_audio(tensor, sample_rate):
    """Write a (channels, samples) tensor to a temporary WAV file and return its path."""
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        torchaudio.save(tmp.name, tensor, sample_rate)
        return tmp.name


@spaces.GPU(duration=300)
def separate_audio(model_name, file_path, text_prompt):
    """Isolate the sound described by text_prompt; returns (target, residual, status)."""
    global model, processor
    if not file_path:
        return None, None, "❌ Please upload an audio or video file."
    if not text_prompt or not text_prompt.strip():
        return None, None, "❌ Please enter a text prompt."
    load_model(model_name)  # validate inputs first so bad requests skip the load
    try:
        inputs = processor(audios=[file_path], descriptions=[text_prompt.strip()]).to(device)
        with torch.inference_mode():
            result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
        sample_rate = processor.audio_sampling_rate
        target_path = save_audio(result.target[0].unsqueeze(0).cpu(), sample_rate)
        residual_path = save_audio(result.residual[0].unsqueeze(0).cpu(), sample_rate)
        return target_path, residual_path, f"✅ Isolated '{text_prompt}' using {model_name}"
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, None, f"❌ Error: {e}"


def process_audio(model_name, audio_path, prompt):
    if not audio_path:
        return None, None, "❌ Please upload an audio file."
    return separate_audio(model_name, audio_path, prompt)


def process_video(model_name, video_path, prompt):
    if not video_path:
        return None, None, "❌ Please upload a video file."
    return separate_audio(model_name, video_path, prompt)


def process_example(model_name, prompt):
    if not os.path.exists(EXAMPLE_FILE):
        return None, None, "❌ Example file not found."
    return separate_audio(model_name, EXAMPLE_FILE, prompt)


def load_example(prompt):
    """Populate the video input and prompt box with the bundled example clip."""
    return EXAMPLE_FILE, prompt


# Build interface
with gr.Blocks(title="SAM-Audio Demo") as demo:
    gr.Markdown(
        """
        # 🎵 SAM-Audio: Segment Anything for Audio

        Isolate specific sounds from audio or video using natural language prompts.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                choices=list(MODELS.keys()), value=DEFAULT_MODEL, label="Model"
            )
            with gr.Tabs():
                with gr.TabItem("🎵 Audio"):
                    input_audio = gr.Audio(label="Upload Audio", type="filepath")
                with gr.TabItem("🎬 Video"):
                    input_video = gr.Video(label="Upload Video")
            text_prompt = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g., 'A man speaking', 'Piano', 'Dog barking'",
            )
            run_btn = gr.Button("🎯 Isolate Sound", variant="primary")
            status_output = gr.Markdown("")
        with gr.Column(scale=1):
            gr.Markdown("### Results")
            output_target = gr.Audio(label="Isolated Sound (Target)")
            output_residual = gr.Audio(label="Background (Residual)")

    gr.Markdown("---")
    gr.Markdown("### 🎬 Demo Examples (click to auto-process)")
    with gr.Row():
        if os.path.exists(EXAMPLE_FILE):
            example_btn1 = gr.Button("🎤 Man Speaking")
            example_btn2 = gr.Button("🎤 Woman Speaking")
            example_btn3 = gr.Button("🎵 Background Music")

    # Dispatch to audio or video processing depending on which input was provided.
    run_btn.click(
        fn=lambda m, a, v, p: process_audio(m, a, p) if a else process_video(m, v, p),
        inputs=[model_selector, input_audio, input_video, text_prompt],
        outputs=[output_target, output_residual, status_output],
    )

    # Example buttons: first fill in the inputs, then run separation.
    if os.path.exists(EXAMPLE_FILE):
        example_btn1.click(
            fn=lambda: load_example("A man speaking"),
            outputs=[input_video, text_prompt],
        ).then(
            fn=lambda m: process_example(m, "A man speaking"),
            inputs=[model_selector],
            outputs=[output_target, output_residual, status_output],
        )
        example_btn2.click(
            fn=lambda: load_example("A woman speaking"),
            outputs=[input_video, text_prompt],
        ).then(
            fn=lambda m: process_example(m, "A woman speaking"),
            inputs=[model_selector],
            outputs=[output_target, output_residual, status_output],
        )
        example_btn3.click(
            fn=lambda: load_example("Background music"),
            outputs=[input_video, text_prompt],
        ).then(
            fn=lambda m: process_example(m, "Background music"),
            inputs=[model_selector],
            outputs=[output_target, output_residual, status_output],
        )

if __name__ == "__main__":
    demo.launch()