import tempfile

import gradio as gr
import torch
import torchaudio

# Fall back to a no-op GPU decorator when running outside Hugging Face Spaces.
try:
    import spaces
except ImportError:
    class spaces:
        @staticmethod
        def GPU(duration=60):
            def decorator(func):
                return func
            return decorator

from sam_audio import SAMAudio, SAMAudioProcessor

# Configuration
MODEL_NAME = "facebook/sam-audio-small"
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading {MODEL_NAME} on {device}...")

# Load model and processor
try:
    model = SAMAudio.from_pretrained(MODEL_NAME).to(device).eval()
    processor = SAMAudioProcessor.from_pretrained(MODEL_NAME)
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model. Did you set HF_TOKEN in secrets? Error: {e}")
    raise


def save_audio(tensor, sample_rate):
    """Save a torch tensor to a temporary WAV file for Gradio output."""
    if tensor.dim() == 1:
        tensor = tensor.unsqueeze(0)  # torchaudio.save expects (channels, samples)
    tensor = tensor.detach().cpu()
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        torchaudio.save(tmp.name, tensor, sample_rate)
    return tmp.name


@spaces.GPU(duration=120)
def separate_audio(audio_path, text_prompt):
    if not audio_path or not text_prompt:
        return None, None

    # Process inputs
    inputs = processor(
        audios=[audio_path],
        descriptions=[text_prompt],
    ).to(device)

    # Inference
    with torch.no_grad():
        result = model.separate(inputs)

    # Extract outputs
    target_audio = result.target[0]      # the sound described by the prompt
    residual_audio = result.residual[0]  # everything else

    # Sampling rate comes from the processor config
    sr = processor.feature_extractor.sampling_rate

    # Save to files
    target_path = save_audio(target_audio, sr)
    residual_path = save_audio(residual_audio, sr)
    return target_path, residual_path


# Build Gradio interface
with gr.Blocks(title="SAM-Audio Demo") as demo:
    gr.Markdown(
        """
        # 🎵 SAM-Audio: Segment Anything for Audio

        Isolate specific sounds from an audio file using natural language prompts.

        **Model:** `facebook/sam-audio-small`
        """
    )
    with gr.Row():
        with gr.Column():
            input_audio = gr.Audio(label="Upload Input Audio", type="filepath")
            text_prompt = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g., 'dog barking', 'man speaking', 'typing keyboard'",
                info="Describe the sound you want to isolate.",
            )
            run_btn = gr.Button("Separate Audio", variant="primary")
        with gr.Column():
            output_target = gr.Audio(label="Isolated Sound (Target)")
            output_residual = gr.Audio(label="Background (Residual)")

    run_btn.click(
        fn=separate_audio,
        inputs=[input_audio, text_prompt],
        outputs=[output_target, output_residual],
    )

# Launch
demo.queue().launch()
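
# Optional smoke test: call the handler directly, bypassing the UI.
# "example.wav" is a hypothetical local file, not shipped with this demo;
# run this in a separate session (launch() above blocks the process).
#
# target_path, residual_path = separate_audio("example.wav", "dog barking")
# print("Target:", target_path)
# print("Residual:", residual_path)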