# sam-audio-webui / app.py
# (Hugging Face Space page chrome removed — author: Peter Shi,
#  commit cebdac8 "Restore audio/video preview with tabs", 6.21 kB)
import spaces
import gradio as gr
import torch
import torchaudio
import tempfile
import warnings
import os
warnings.filterwarnings("ignore")
from sam_audio import SAMAudio, SAMAudioProcessor
# Available models
# Maps the dropdown label shown in the UI to its Hugging Face checkpoint id.
# Entries marked "(Visual)" are the text+video ("-tv") model variants.
MODELS = {
    "sam-audio-small": "facebook/sam-audio-small",
    "sam-audio-base": "facebook/sam-audio-base",
    "sam-audio-large": "facebook/sam-audio-large",
    "sam-audio-small-tv (Visual)": "facebook/sam-audio-small-tv",
    "sam-audio-base-tv (Visual)": "facebook/sam-audio-base-tv",
    "sam-audio-large-tv (Visual)": "facebook/sam-audio-large-tv",
}
DEFAULT_MODEL = "sam-audio-small"  # must be a key of MODELS
EXAMPLES_DIR = "examples"
# Bundled demo clip driven by the example buttons in the UI.
EXAMPLE_FILE = os.path.join(EXAMPLES_DIR, "office.mp4")
# Global model cache
# Model and processor are cached at module level so repeated requests with
# the same model name skip reloading (see load_model below).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
current_model_name = None  # UI name of the currently loaded model, or None
model = None
processor = None
def load_model(model_name):
    """Load the requested SAM-Audio model/processor pair into the module cache.

    A repeated request for the already-loaded model is a no-op. Unknown
    names silently fall back to the DEFAULT_MODEL checkpoint.
    """
    global current_model_name, model, processor
    # Cache hit: the requested model is already resident.
    if model is not None and current_model_name == model_name:
        return
    checkpoint = MODELS.get(model_name, MODELS[DEFAULT_MODEL])
    print(f"Loading {checkpoint}...")
    model = SAMAudio.from_pretrained(checkpoint).to(device).eval()
    processor = SAMAudioProcessor.from_pretrained(checkpoint)
    current_model_name = model_name
    print(f"Model {checkpoint} loaded on {device}.")
# Eagerly load the default model at import time so the first request is fast.
load_model(DEFAULT_MODEL)
def save_audio(tensor, sample_rate):
    """Write *tensor* as a .wav file in the OS temp dir and return its path.

    The file is created with delete=False on purpose: it must outlive this
    call so Gradio can serve it back to the browser.
    """
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    with tmp:
        torchaudio.save(tmp.name, tensor, sample_rate)
    return tmp.name
@spaces.GPU(duration=300)
def separate_audio(model_name, file_path, text_prompt):
    """Isolate the sound described by *text_prompt* from the media at *file_path*.

    Returns a ``(target_path, residual_path, status_message)`` tuple; the two
    paths are ``None`` whenever validation fails or separation raises.
    """
    global model, processor
    # Validate the cheap preconditions BEFORE load_model(): the original code
    # loaded a (potentially multi-GB) checkpoint even for empty inputs.
    if not file_path:
        return None, None, "❌ Please upload an audio or video file."
    if not text_prompt or not text_prompt.strip():
        return None, None, "❌ Please enter a text prompt."
    load_model(model_name)
    try:
        inputs = processor(audios=[file_path], descriptions=[text_prompt.strip()]).to(device)
        with torch.inference_mode():
            result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
        sample_rate = processor.audio_sampling_rate
        # unsqueeze(0) restores the channel dimension torchaudio.save expects.
        target_path = save_audio(result.target[0].unsqueeze(0).cpu(), sample_rate)
        residual_path = save_audio(result.residual[0].unsqueeze(0).cpu(), sample_rate)
        # Fixed mojibake: the status prefix was "βœ…" (UTF-8 "✅" mis-decoded).
        return target_path, residual_path, f"✅ Isolated '{text_prompt}' using {model_name}"
    except Exception as e:
        # Boundary handler: log the full traceback, surface the message in the UI.
        import traceback
        traceback.print_exc()
        return None, None, f"❌ Error: {str(e)}"
def process_audio(model_name, audio_path, prompt):
    """Run separation on an uploaded audio file, guarding the missing-file case."""
    if audio_path:
        return separate_audio(model_name, audio_path, prompt)
    return None, None, "❌ Please upload an audio file."
def process_video(model_name, video_path, prompt):
    """Run separation on an uploaded video file, guarding the missing-file case."""
    if video_path:
        return separate_audio(model_name, video_path, prompt)
    return None, None, "❌ Please upload a video file."
def process_example(model_name, prompt):
    """Run separation on the bundled demo clip with the given *prompt*."""
    if os.path.exists(EXAMPLE_FILE):
        return separate_audio(model_name, EXAMPLE_FILE, prompt)
    return None, None, "❌ Example file not found."
def load_example(prompt):
    """Return (demo clip path, prompt) for populating the UI inputs."""
    return EXAMPLE_FILE, prompt
# Build Interface
# Declarative Gradio UI: left column holds the inputs (model picker, tabbed
# audio/video upload, prompt, run button), right column the separated outputs.
# NOTE(review): the 🎡/🎀 emojis below look like mojibake for 🎵/🎤
# (same corruption as elsewhere in this file) — confirm before changing,
# since they are user-visible strings.
with gr.Blocks(title="SAM-Audio Demo") as demo:
    gr.Markdown(
        """
# 🎡 SAM-Audio: Segment Anything for Audio
Isolate specific sounds from audio or video using natural language prompts.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=DEFAULT_MODEL,
                label="Model"
            )
            # Audio and video uploads live in separate tabs; the click
            # handler below prefers audio when both hold a file.
            with gr.Tabs():
                with gr.TabItem("🎡 Audio"):
                    input_audio = gr.Audio(label="Upload Audio", type="filepath")
                with gr.TabItem("🎬 Video"):
                    input_video = gr.Video(label="Upload Video")
            text_prompt = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g., 'A man speaking', 'Piano', 'Dog barking'"
            )
            run_btn = gr.Button("🎯 Isolate Sound", variant="primary")
            status_output = gr.Markdown("")
        with gr.Column(scale=1):
            gr.Markdown("### Results")
            output_target = gr.Audio(label="Isolated Sound (Target)")
            output_residual = gr.Audio(label="Background (Residual)")
    gr.Markdown("---")
    gr.Markdown("### 🎬 Demo Examples (click to auto-process)")
    # Example buttons are only created when the bundled clip exists on disk.
    with gr.Row():
        if os.path.exists(EXAMPLE_FILE):
            example_btn1 = gr.Button("🎀 Man Speaking")
            example_btn2 = gr.Button("🎀 Woman Speaking")
            example_btn3 = gr.Button("🎡 Background Music")
    # Audio processing
    # Route to audio or video processing depending on which input is set
    # (audio takes precedence when both are populated).
    run_btn.click(
        fn=lambda m, a, v, p: process_audio(m, a, p) if a else process_video(m, v, p),
        inputs=[model_selector, input_audio, input_video, text_prompt],
        outputs=[output_target, output_residual, status_output]
    )
    # Example buttons
    # Each button first fills the video preview + prompt, then chains
    # (.then) into actual separation so the user sees what is processed.
    if os.path.exists(EXAMPLE_FILE):
        example_btn1.click(
            fn=lambda: (EXAMPLE_FILE, "A man speaking"),
            outputs=[input_video, text_prompt]
        ).then(
            fn=lambda m: process_example(m, "A man speaking"),
            inputs=[model_selector],
            outputs=[output_target, output_residual, status_output]
        )
        example_btn2.click(
            fn=lambda: (EXAMPLE_FILE, "A woman speaking"),
            outputs=[input_video, text_prompt]
        ).then(
            fn=lambda m: process_example(m, "A woman speaking"),
            inputs=[model_selector],
            outputs=[output_target, output_residual, status_output]
        )
        example_btn3.click(
            fn=lambda: (EXAMPLE_FILE, "Background music"),
            outputs=[input_video, text_prompt]
        ).then(
            fn=lambda m: process_example(m, "Background music"),
            inputs=[model_selector],
            outputs=[output_target, output_residual, status_output]
        )
# Launch the Gradio server when run as a script (Spaces also imports `demo`).
if __name__ == "__main__":
    demo.launch()