Peter Shi committed
Commit c12dff8 · Parent: 42e2df6

Add example with Git LFS for MP4

Files changed (3):
  1. .gitattributes +1 -0
  2. app.py +114 -20
  3. examples/office.mp4 +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
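
Note: the appended pattern is exactly what `git lfs track "*.mp4"` writes into `.gitattributes`. A hedged sketch of reproducing this commit's LFS setup from Python (plain shell works just as well; the file names and commit message are taken from this diff):

```python
# Sketch: reproduce this commit's LFS setup by shelling out to git.
# Equivalent to running the same commands in a terminal; requires git-lfs.
import subprocess

for cmd in (
    ["git", "lfs", "install"],                  # one-time hook setup
    ["git", "lfs", "track", "*.mp4"],           # appends the .gitattributes line above
    ["git", "add", ".gitattributes", "examples/office.mp4"],
    ["git", "commit", "-m", "Add example with Git LFS for MP4"],
):
    subprocess.run(cmd, check=True)
```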
app.py CHANGED
@@ -4,19 +4,51 @@ import torch
  import torchaudio
  import tempfile
  import warnings
+ import os
  warnings.filterwarnings("ignore")

  from sam_audio import SAMAudio, SAMAudioProcessor

- # Configuration
- MODEL_NAME = "facebook/sam-audio-small"
+ # Available models
+ MODELS = {
+     "sam-audio-small": "facebook/sam-audio-small",
+     "sam-audio-base": "facebook/sam-audio-base",
+     "sam-audio-large": "facebook/sam-audio-large",
+     "sam-audio-small-tv (Visual)": "facebook/sam-audio-small-tv",
+     "sam-audio-base-tv (Visual)": "facebook/sam-audio-base-tv",
+     "sam-audio-large-tv (Visual)": "facebook/sam-audio-large-tv",
+ }

- # Load model and processor
- print(f"Loading {MODEL_NAME}...")
+ # Default model
+ DEFAULT_MODEL = "sam-audio-small"
+
+ # Example files
+ EXAMPLES_DIR = "examples"
+ EXAMPLE_FILE = os.path.join(EXAMPLES_DIR, "office.mp4")
+
+ # Global model cache
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- model = SAMAudio.from_pretrained(MODEL_NAME).to(device).eval()
- processor = SAMAudioProcessor.from_pretrained(MODEL_NAME)
- print(f"Model loaded on {device}.")
+ current_model_name = None
+ model = None
+ processor = None
+
+ def load_model(model_name):
+     """Load or switch model."""
+     global current_model_name, model, processor
+
+     model_id = MODELS.get(model_name, MODELS[DEFAULT_MODEL])
+
+     if current_model_name == model_name and model is not None:
+         return
+
+     print(f"Loading {model_id}...")
+     model = SAMAudio.from_pretrained(model_id).to(device).eval()
+     processor = SAMAudioProcessor.from_pretrained(model_id)
+     current_model_name = model_name
+     print(f"Model {model_id} loaded on {device}.")
+
+ # Load default model at startup
+ load_model(DEFAULT_MODEL)

  def save_audio(tensor, sample_rate):
      """Helper to save torch tensor to a temp file for Gradio output."""
@@ -25,9 +57,11 @@ def save_audio(tensor, sample_rate):
      return tmp.name

  @spaces.GPU(duration=300)
- def separate_audio(audio_path, video_path, text_prompt):
-     # Determine which input to use
-     file_path = video_path if video_path else audio_path
+ def separate_audio(model_name, file_path, text_prompt):
+     global model, processor
+
+     # Load selected model if different
+     load_model(model_name)

      if not file_path:
          return None, None, "❌ Please upload an audio or video file."
@@ -48,14 +82,27 @@ def separate_audio(audio_path, video_path, text_prompt):
          target_path = save_audio(result.target[0].unsqueeze(0).cpu(), sample_rate)
          residual_path = save_audio(result.residual[0].unsqueeze(0).cpu(), sample_rate)

-         return target_path, residual_path, f"✅ Successfully isolated '{text_prompt}'"
+         return target_path, residual_path, f"✅ Successfully isolated '{text_prompt}' using {model_name}"

      except Exception as e:
          import traceback
          traceback.print_exc()
          return None, None, f"❌ Error: {str(e)}"

- # Build Gradio Interface - Simple and clean
+ def process_file(model_name, file, prompt):
+     if file is None:
+         return None, None, "❌ Please upload a file."
+     # Handle both file object and file path
+     file_path = file.name if hasattr(file, 'name') else file
+     return separate_audio(model_name, file_path, prompt)
+
+ def process_example(model_name, file_path, prompt):
+     """Process directly from example - file_path is already a string."""
+     if not file_path or not os.path.exists(file_path):
+         return None, None, "❌ Example file not found."
+     return separate_audio(model_name, file_path, prompt)
+
+ # Build Gradio Interface
  with gr.Blocks(title="SAM-Audio Demo") as demo:
      gr.Markdown(
          """
@@ -63,15 +110,23 @@ with gr.Blocks(title="SAM-Audio Demo") as demo:

      Isolate specific sounds from an audio or video file using natural language prompts.

-     **Model:** [facebook/sam-audio-small](https://huggingface.co/facebook/sam-audio-small)
+     **Models:** [facebook/sam-audio](https://huggingface.co/collections/facebook/sam-audio-67608edbf75ad66bf5e8cb3a)
      """
      )

      with gr.Row():
          with gr.Column():
-             gr.Markdown("### Upload Audio or Video")
-             input_audio = gr.Audio(label="Audio File", type="filepath")
-             input_video = gr.Video(label="Video File (MP4)")
+             model_selector = gr.Dropdown(
+                 choices=list(MODELS.keys()),
+                 value=DEFAULT_MODEL,
+                 label="Model",
+                 info="Larger = better quality but slower. TV variants for visual prompting."
+             )
+
+             input_file = gr.File(
+                 label="Upload Audio or Video",
+                 file_types=[".mp3", ".wav", ".flac", ".ogg", ".m4a", ".mp4", ".mkv", ".avi", ".mov", ".webm"],
+             )

              text_prompt = gr.Textbox(
                  label="Text Prompt",
@@ -88,14 +143,53 @@ with gr.Blocks(title="SAM-Audio Demo") as demo:
      output_residual = gr.Audio(label="Background (Residual)")

      gr.Markdown("---")
-     gr.Markdown("### Example Prompts")
-     gr.Markdown("- A man speaking\n- A woman singing\n- Piano\n- Drums\n- Guitar\n- Dog barking\n- Car engine")
+     gr.Markdown("### 🎬 Try Demo Examples")
+     gr.Markdown("Click an example below to auto-fill and process:")
+
+     with gr.Row():
+         if os.path.exists(EXAMPLE_FILE):
+             example_btn1 = gr.Button("🎤 Man Speaking")
+             example_btn2 = gr.Button("🎤 Woman Speaking")
+             example_btn3 = gr.Button("🎵 Background Music")
+
+     gr.Markdown("---")
+     gr.Markdown("**Supported formats:** MP3, WAV, FLAC, OGG, M4A, MP4, MKV, AVI, MOV, WebM")

+     # Main run button
      run_btn.click(
-         fn=separate_audio,
-         inputs=[input_audio, input_video, text_prompt],
+         fn=process_file,
+         inputs=[model_selector, input_file, text_prompt],
          outputs=[output_target, output_residual, status_output]
      )
+
+     # Example buttons - auto-fill and process
+     if os.path.exists(EXAMPLE_FILE):
+         example_btn1.click(
+             fn=lambda: (EXAMPLE_FILE, "A man speaking"),
+             outputs=[input_file, text_prompt]
+         ).then(
+             fn=lambda m: process_example(m, EXAMPLE_FILE, "A man speaking"),
+             inputs=[model_selector],
+             outputs=[output_target, output_residual, status_output]
+         )
+
+         example_btn2.click(
+             fn=lambda: (EXAMPLE_FILE, "A woman speaking"),
+             outputs=[input_file, text_prompt]
+         ).then(
+             fn=lambda m: process_example(m, EXAMPLE_FILE, "A woman speaking"),
+             inputs=[model_selector],
+             outputs=[output_target, output_residual, status_output]
+         )
+
+         example_btn3.click(
+             fn=lambda: (EXAMPLE_FILE, "Background music"),
+             outputs=[input_file, text_prompt]
+         ).then(
+             fn=lambda m: process_example(m, EXAMPLE_FILE, "Background music"),
+             inputs=[model_selector],
+             outputs=[output_target, output_residual, status_output]
+         )

  if __name__ == "__main__":
      demo.launch()
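
A hedged sketch of exercising the refactored entry points headlessly (hypothetical usage, not part of the commit). It assumes `@spaces.GPU` degrades to a no-op outside a ZeroGPU Space, and that importing `app` runs its top-level code, which loads the default model but leaves `demo.launch()` behind the `__main__` guard:

```python
# Hypothetical headless driver for the helpers added in this commit.
# Importing app executes its module-level code, so the default
# "sam-audio-small" checkpoint loads once on import.
from app import load_model, separate_audio

load_model("sam-audio-base")  # swap the global cache to a larger checkpoint

# Same (model_name, file_path, text_prompt) signature as the Gradio handler.
target_path, residual_path, status = separate_audio(
    "sam-audio-base", "examples/office.mp4", "A man speaking"
)
print(status)         # "✅ Successfully isolated 'A man speaking' using sam-audio-base"
print(target_path)    # temp audio file with the isolated sound
print(residual_path)  # temp audio file with the remaining background
```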
examples/office.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c0f583ff34c5fd9d1a83d640e7c0131ad339755bd69e54f104723b707f213c21
+ size 4551702
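
Only this three-line pointer is committed; the ~4.5 MB video itself lives in LFS storage. A minimal sketch of reading such a pointer in a checkout where LFS has not yet replaced it with the real file (the key/value format is exactly what the hunk above shows):

```python
# Minimal sketch: parse a Git LFS pointer file such as examples/office.mp4.
# A pointer is a tiny "key value" text file; the real bytes live in LFS storage.
def parse_lfs_pointer(path: str) -> dict[str, str]:
    fields = {}
    with open(path, encoding="utf-8") as fh:
        for line in fh:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields

ptr = parse_lfs_pointer("examples/office.mp4")  # only works on the pointer, not the video
assert ptr["version"] == "https://git-lfs.github.com/spec/v1"
print(ptr["oid"])   # sha256:c0f583ff...
print(ptr["size"])  # 4551702 (bytes)
```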