# SAM-Audio Gradio demo (Hugging Face Space).
# NOTE: the Space was displaying "Runtime error" at capture time.
import os
import tempfile
import warnings

import gradio as gr
import spaces
import torch
import torchaudio

# Silence library warnings before importing the model package.
warnings.filterwarnings("ignore")

from sam_audio import SAMAudio, SAMAudioProcessor
| # Available models | |
| MODELS = { | |
| "sam-audio-small": "facebook/sam-audio-small", | |
| "sam-audio-base": "facebook/sam-audio-base", | |
| "sam-audio-large": "facebook/sam-audio-large", | |
| "sam-audio-small-tv (Visual)": "facebook/sam-audio-small-tv", | |
| "sam-audio-base-tv (Visual)": "facebook/sam-audio-base-tv", | |
| "sam-audio-large-tv (Visual)": "facebook/sam-audio-large-tv", | |
| } | |
| DEFAULT_MODEL = "sam-audio-small" | |
| EXAMPLES_DIR = "examples" | |
| EXAMPLE_FILE = os.path.join(EXAMPLES_DIR, "office.mp4") | |
| # Global model cache | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| current_model_name = None | |
| model = None | |
| processor = None | |
| def load_model(model_name): | |
| global current_model_name, model, processor | |
| model_id = MODELS.get(model_name, MODELS[DEFAULT_MODEL]) | |
| if current_model_name == model_name and model is not None: | |
| return | |
| print(f"Loading {model_id}...") | |
| model = SAMAudio.from_pretrained(model_id).to(device).eval() | |
| processor = SAMAudioProcessor.from_pretrained(model_id) | |
| current_model_name = model_name | |
| print(f"Model {model_id} loaded on {device}.") | |
| load_model(DEFAULT_MODEL) | |
| def save_audio(tensor, sample_rate): | |
| with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: | |
| torchaudio.save(tmp.name, tensor, sample_rate) | |
| return tmp.name | |
| def separate_audio(model_name, file_path, text_prompt): | |
| global model, processor | |
| load_model(model_name) | |
| if not file_path: | |
| return None, None, "β Please upload an audio or video file." | |
| if not text_prompt or not text_prompt.strip(): | |
| return None, None, "β Please enter a text prompt." | |
| try: | |
| inputs = processor(audios=[file_path], descriptions=[text_prompt.strip()]).to(device) | |
| with torch.inference_mode(): | |
| result = model.separate(inputs, predict_spans=False, reranking_candidates=1) | |
| sample_rate = processor.audio_sampling_rate | |
| target_path = save_audio(result.target[0].unsqueeze(0).cpu(), sample_rate) | |
| residual_path = save_audio(result.residual[0].unsqueeze(0).cpu(), sample_rate) | |
| return target_path, residual_path, f"β Isolated '{text_prompt}' using {model_name}" | |
| except Exception as e: | |
| import traceback | |
| traceback.print_exc() | |
| return None, None, f"β Error: {str(e)}" | |
| def process_audio(model_name, audio_path, prompt): | |
| if not audio_path: | |
| return None, None, "β Please upload an audio file." | |
| return separate_audio(model_name, audio_path, prompt) | |
| def process_video(model_name, video_path, prompt): | |
| if not video_path: | |
| return None, None, "β Please upload a video file." | |
| return separate_audio(model_name, video_path, prompt) | |
| def process_example(model_name, prompt): | |
| if not os.path.exists(EXAMPLE_FILE): | |
| return None, None, "β Example file not found." | |
| return separate_audio(model_name, EXAMPLE_FILE, prompt) | |
| def load_example(prompt): | |
| return EXAMPLE_FILE, prompt | |
| # Build Interface | |
| with gr.Blocks(title="SAM-Audio Demo") as demo: | |
| gr.Markdown( | |
| """ | |
| # π΅ SAM-Audio: Segment Anything for Audio | |
| Isolate specific sounds from audio or video using natural language prompts. | |
| """ | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| model_selector = gr.Dropdown( | |
| choices=list(MODELS.keys()), | |
| value=DEFAULT_MODEL, | |
| label="Model" | |
| ) | |
| with gr.Tabs(): | |
| with gr.TabItem("π΅ Audio"): | |
| input_audio = gr.Audio(label="Upload Audio", type="filepath") | |
| with gr.TabItem("π¬ Video"): | |
| input_video = gr.Video(label="Upload Video") | |
| text_prompt = gr.Textbox( | |
| label="Text Prompt", | |
| placeholder="e.g., 'A man speaking', 'Piano', 'Dog barking'" | |
| ) | |
| run_btn = gr.Button("π― Isolate Sound", variant="primary") | |
| status_output = gr.Markdown("") | |
| with gr.Column(scale=1): | |
| gr.Markdown("### Results") | |
| output_target = gr.Audio(label="Isolated Sound (Target)") | |
| output_residual = gr.Audio(label="Background (Residual)") | |
| gr.Markdown("---") | |
| gr.Markdown("### π¬ Demo Examples (click to auto-process)") | |
| with gr.Row(): | |
| if os.path.exists(EXAMPLE_FILE): | |
| example_btn1 = gr.Button("π€ Man Speaking") | |
| example_btn2 = gr.Button("π€ Woman Speaking") | |
| example_btn3 = gr.Button("π΅ Background Music") | |
| # Audio processing | |
| run_btn.click( | |
| fn=lambda m, a, v, p: process_audio(m, a, p) if a else process_video(m, v, p), | |
| inputs=[model_selector, input_audio, input_video, text_prompt], | |
| outputs=[output_target, output_residual, status_output] | |
| ) | |
| # Example buttons | |
| if os.path.exists(EXAMPLE_FILE): | |
| example_btn1.click( | |
| fn=lambda: (EXAMPLE_FILE, "A man speaking"), | |
| outputs=[input_video, text_prompt] | |
| ).then( | |
| fn=lambda m: process_example(m, "A man speaking"), | |
| inputs=[model_selector], | |
| outputs=[output_target, output_residual, status_output] | |
| ) | |
| example_btn2.click( | |
| fn=lambda: (EXAMPLE_FILE, "A woman speaking"), | |
| outputs=[input_video, text_prompt] | |
| ).then( | |
| fn=lambda m: process_example(m, "A woman speaking"), | |
| inputs=[model_selector], | |
| outputs=[output_target, output_residual, status_output] | |
| ) | |
| example_btn3.click( | |
| fn=lambda: (EXAMPLE_FILE, "Background music"), | |
| outputs=[input_video, text_prompt] | |
| ).then( | |
| fn=lambda m: process_example(m, "Background music"), | |
| inputs=[model_selector], | |
| outputs=[output_target, output_residual, status_output] | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch() | |