Spaces:
Runtime error
Runtime error
Peter Shi
feat: Migrate the deployment to the Gradio SDK, integrate the `spaces.GPU` decorator, and remove the Dockerfile.
1b3117a
| import gradio as gr | |
| import torch | |
| import torchaudio | |
| import tempfile | |
try:
    # Real decorator when running on a Hugging Face Space (ZeroGPU hardware).
    import spaces
except ImportError:
    # Local stand-in so the app also runs on machines without the `spaces`
    # package: @spaces.GPU(...) becomes a no-op.
    class spaces:
        @staticmethod
        def GPU(func=None, duration=60):
            """No-op replacement for spaces.GPU.

            Supports both usage forms:
              @spaces.GPU                 (bare decorator)
              @spaces.GPU(duration=60)    (decorator factory)
            """
            def decorator(f):
                return f
            if callable(func):
                # Used as a bare decorator: `func` is the wrapped function.
                return func
            return decorator
| from sam_audio import SAMAudio, SAMAudioProcessor | |
# Configuration
MODEL_NAME = "facebook/sam-audio-small"
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading {MODEL_NAME} on {device}...")

# Load the model and processor once at import time so every request reuses
# the same instances; .eval() disables dropout/batch-norm updates.
try:
    model = SAMAudio.from_pretrained(MODEL_NAME).to(device).eval()
    processor = SAMAudioProcessor.from_pretrained(MODEL_NAME)
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model. Did you set HF_TOKEN in secrets? Error: {e}")
    # Bare `raise` preserves the original traceback (`raise e` would rewrite
    # the raise site).
    raise
def save_audio(tensor, sample_rate):
    """Save a waveform tensor to a temp WAV file and return its path.

    Args:
        tensor: 1-D (samples,) or 2-D (channels, samples) waveform tensor,
            possibly on GPU and/or requiring grad.
        sample_rate: Sampling rate in Hz to write into the WAV header.

    Returns:
        Path to the written .wav file (not auto-deleted; Gradio serves it).
    """
    # torchaudio.save expects a (channels, samples) layout.
    if tensor.dim() == 1:
        tensor = tensor.unsqueeze(0)
    tensor = tensor.detach().cpu()
    # Close the handle before writing: on Windows an open NamedTemporaryFile
    # cannot be reopened by another writer, so saving inside a `with` block
    # would fail. delete=False keeps the file around for Gradio to serve.
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp.close()
    torchaudio.save(tmp.name, tensor, sample_rate)
    return tmp.name
# The commit intent was to use ZeroGPU; without @spaces.GPU the GPU is never
# attached to the worker on Spaces ZeroGPU hardware, so inference fails.
@spaces.GPU(duration=120)
def separate_audio(audio_path, text_prompt):
    """Isolate the sound described by `text_prompt` from `audio_path`.

    Args:
        audio_path: Filesystem path to the uploaded audio (gr.Audio filepath).
        text_prompt: Natural-language description of the sound to isolate.

    Returns:
        (target_path, residual_path): temp WAV file paths for the isolated
        sound and the remaining background, or (None, None) when no audio
        was provided.
    """
    if not audio_path:
        return None, None

    # Featurize the (audio, description) pair and move tensors to the device.
    inputs = processor(
        audios=[audio_path],
        descriptions=[text_prompt]
    ).to(device)

    # Inference — no gradients needed.
    with torch.no_grad():
        result = model.separate(inputs)

    # Extract outputs
    target_audio = result.target[0]      # The sound you asked for
    residual_audio = result.residual[0]  # Everything else

    # Output sampling rate comes from the processor's feature extractor config.
    sr = processor.feature_extractor.sampling_rate

    # Save to files for the two gr.Audio outputs.
    target_path = save_audio(target_audio, sr)
    residual_path = save_audio(residual_audio, sr)
    return target_path, residual_path
# Build Gradio Interface
with gr.Blocks(title="SAM-Audio Demo") as demo:
    # Header / usage blurb shown at the top of the page.
    gr.Markdown(
        """
        # 🎵 SAM-Audio: Segment Anything for Audio
        Isolate specific sounds from an audio file using natural language prompts.
        **Model:** `facebook/sam-audio-small`
        """
    )
    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            # type="filepath" hands separate_audio a path string, which is
            # what the processor expects.
            input_audio = gr.Audio(label="Upload Input Audio", type="filepath")
            text_prompt = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g., 'dog barking', 'man speaking', 'typing keyboard'",
                info="Describe the sound you want to isolate."
            )
            run_btn = gr.Button("Separate Audio", variant="primary")
        # Right column: the two separated outputs.
        with gr.Column():
            output_target = gr.Audio(label="Isolated Sound (Target)")
            output_residual = gr.Audio(label="Background (Residual)")

    # Wire the button to the inference function.
    run_btn.click(
        fn=separate_audio,
        inputs=[input_audio, text_prompt],
        outputs=[output_target, output_residual]
    )

# Launch
# queue() is required so long-running GPU jobs don't block the HTTP workers.
demo.queue().launch()