import os

# Limit BLAS/OpenMP thread pools before torch/numpy load so the settings actually take effect.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

import gradio as gr
import torch
import torchaudio
import numpy as np
import tempfile
import spaces
import torch.nn.functional as F
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts


# --- Custom Theme Configuration ---
class MidnightTheme(Soft):
    def __init__(self):
        super().__init__(
            primary_hue=colors.indigo,
            neutral_hue=colors.slate,
            font=(fonts.GoogleFont("Outfit"), "Arial", "sans-serif"),
        )
        super().set(
            body_background_fill="#030617",
            block_background_fill="#10172b",
            block_border_color="#20293c",
            body_text_color="#6366f1",
            button_primary_background_fill="#5248e9",
            input_background_fill="#030617",
        )


midnight_theme = MidnightTheme()

# --- Model Loading ---
try:
    from sam_audio import SAMAudio, SAMAudioProcessor
except ImportError:
    print("Warning: 'sam_audio' library not found.")

MODEL_ID = "facebook/sam-audio-large"
CACHE_DIR = os.path.abspath("hf_models")

try:
    # Load model once on CPU initially
    model = SAMAudio.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR).eval()
    processor = SAMAudioProcessor.from_pretrained(MODEL_ID)
except Exception as e:
    print(f"❌ Error loading SAM-Audio: {e}")
    model, processor = None, None


# --- Optimized Audio Utilities ---
def load_audio(file_path):
    waveform, sample_rate = torchaudio.load(file_path)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    return waveform, sample_rate


def merge_chunks(chunks, sample_rate, overlap_duration=2.0):
    if len(chunks) == 1:
        return chunks[0] if chunks[0].dim() == 2 else chunks[0].unsqueeze(0)
    overlap_samples = int(overlap_duration * sample_rate)
    result = chunks[0]
    for i in range(1, len(chunks)):
        next_chunk = chunks[i]
        actual_overlap = min(overlap_samples, result.shape[1], next_chunk.shape[1])
        fade_out = torch.linspace(1.0, 0.0, actual_overlap)
        fade_in = torch.linspace(0.0, 1.0, actual_overlap)
        crossfaded = result[:, -actual_overlap:] * fade_out + next_chunk[:, :actual_overlap] * fade_in
        result = torch.cat([result[:, :-actual_overlap], crossfaded, next_chunk[:, actual_overlap:]], dim=1)
    return result


def save_audio(tensor, sample_rate):
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        torchaudio.save(tmp.name, tensor.cpu(), sample_rate)
    return tmp.name


# --- Optimized Processing Core ---
def separate_step(waveform, sample_rate, prompt, device, chunk_dur=30.0):
    """
    Processes a single prompt and returns (target_tensor, residual_tensor).
    Includes padding logic to prevent 'tuple index out of range' on the last chunk.
""" chunk_samples = int(chunk_dur * sample_rate) total_samples = waveform.shape[1] target_chunks = [] residual_chunks = [] for start in range(0, total_samples, chunk_samples): end = min(start + chunk_samples, total_samples) chunk = waveform[:, start:end] original_length = chunk.shape[1] # --- FIX: Pad chunk if it is shorter than chunk_samples --- if original_length < chunk_samples: pad_amount = chunk_samples - original_length # F.pad format for 2D input: (padding_left, padding_right) chunk = F.pad(chunk, (0, pad_amount)) with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: torchaudio.save(tmp.name, chunk, sample_rate) chunk_path = tmp.name try: inputs = processor(audios=[chunk_path], descriptions=[prompt]).to(device) with torch.inference_mode(): # Reranking candidates=1 usually returns a list/tuple res = model.separate(inputs, predict_spans=False, reranking_candidates=1) # Get outputs tgt = res.target[0].cpu() rem = res.residual[0].cpu() # --- FIX: Trim padding back off --- if original_length < chunk_samples: tgt = tgt[:, :original_length] rem = rem[:, :original_length] target_chunks.append(tgt) residual_chunks.append(rem) finally: if os.path.exists(chunk_path): os.unlink(chunk_path) return merge_chunks(target_chunks, sample_rate), merge_chunks(residual_chunks, sample_rate) @spaces.GPU(duration=240) def process_dynamic_stems(file_path, selected_prompts, chunk_dur, progress=gr.Progress()): if not file_path or not selected_prompts: return [None]*4 + ["❌ Please upload audio and select at least one prompt."] if len(selected_prompts) > 3: return [None]*4 + ["❌ Maximum 3 selections allowed."] device = "cuda" if torch.cuda.is_available() else "cpu" model.to(device) try: current_waveform, sample_rate = load_audio(file_path) outputs = [None, None, None, None] for i, prompt in enumerate(selected_prompts): progress((i+1)/(len(selected_prompts)+1), desc=f"Extracting {prompt}...") target, residual = separate_step(current_waveform, sample_rate, prompt, device, chunk_dur) outputs[i] = save_audio(target, sample_rate) current_waveform = residual outputs[3] = save_audio(current_waveform, sample_rate) return outputs[0], outputs[1], outputs[2], outputs[3], f"✅ Successfully isolated: {', '.join(selected_prompts)}" except Exception as e: return [None]*4 + [f"❌ Error: {str(e)}"] finally: model.to("cpu") torch.cuda.empty_cache() # --- Gradio UI --- with gr.Blocks(theme=midnight_theme) as demo: gr.Markdown("# 🎙️ SAM-Audio Multi-Stem Splitter") with gr.Row(): with gr.Column(): input_audio = gr.Audio(label="Upload Audio", type="filepath") prompt_selection = gr.Dropdown( choices=["Vocals", "Drums", "Kick", "Snare", "Bass", "Electric Guitar", "Piano", "Violin", "Saxophone"], value=["Vocals", "Drums", "Bass"], multiselect=True, max_choices=3, allow_custom_value=True, label="Select/Type up to 3 Target Sounds", info="Separate custom types with Enter or comma." ) chunk_slider = gr.Slider(10, 60, value=30, step=5, label="Processing Chunk Duration (s)") run_btn = gr.Button("🚀 Start Separation", variant="primary") with gr.Column(): out1 = gr.Audio(label="Target 1") out2 = gr.Audio(label="Target 2") out3 = gr.Audio(label="Target 3") out_other = gr.Audio(label="Remainder (Other)") status = gr.Textbox(label="Status", interactive=False) run_btn.click( fn=process_dynamic_stems, inputs=[input_audio, prompt_selection, chunk_slider], outputs=[out1, out2, out3, out_other, status] ) if __name__ == "__main__": demo.launch()