Peter Shi committed
Commit f4c6545 · 1 Parent(s): 8752ef6

fix: Fixed `merge_chunks_with_crossfade` so it correctly handles one-dimensional audio chunks and chunks shorter than the overlap region, and removed the redundant dimension expansion in the `save_audio` calls.

Files changed (1): app.py (+33 -12)
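
For context, the two failure modes this commit guards against can be reproduced with a short sketch; the tensor shapes, sample counts, and variable names below are illustrative only, not taken from app.py:

import torch

# A mono chunk returned as [samples] instead of [channels, samples].
mono_chunk = torch.randn(16000)
try:
    mono_chunk[:, -4000:]            # 2-D slicing on a 1-D tensor
except IndexError as err:
    print("1-D chunk:", err)         # too many indices for tensor of dimension 1

# A trailing chunk shorter than the overlap window: the sliced overlap region
# (1000 samples) no longer matches the 4000-sample fade curve, so the
# crossfade multiply cannot broadcast.
short_chunk = torch.randn(1, 1000)
fade_out = torch.linspace(1.0, 0.0, 4000)
try:
    short_chunk[:, -4000:] * fade_out
except RuntimeError as err:
    print("short chunk:", err)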
app.py CHANGED
@@ -79,31 +79,51 @@ def split_audio_into_chunks(waveform, sample_rate, chunk_duration, overlap_durat
 def merge_chunks_with_crossfade(chunks, sample_rate, overlap_duration):
     """Merge audio chunks with crossfade on overlapping regions."""
     if len(chunks) == 1:
-        return chunks[0]
+        chunk = chunks[0]
+        # Ensure 2D tensor
+        if chunk.dim() == 1:
+            chunk = chunk.unsqueeze(0)
+        return chunk

     overlap_samples = int(overlap_duration * sample_rate)
-    result = chunks[0]

-    for i in range(1, len(chunks)):
+    # Ensure all chunks are 2D [channels, samples]
+    processed_chunks = []
+    for chunk in chunks:
+        if chunk.dim() == 1:
+            chunk = chunk.unsqueeze(0)
+        processed_chunks.append(chunk)
+
+    result = processed_chunks[0]
+
+    for i in range(1, len(processed_chunks)):
         prev_chunk = result
-        next_chunk = chunks[i]
+        next_chunk = processed_chunks[i]
+
+        # Handle case where chunks are shorter than overlap
+        actual_overlap = min(overlap_samples, prev_chunk.shape[1], next_chunk.shape[1])
+
+        if actual_overlap <= 0:
+            # No overlap possible, just concatenate
+            result = torch.cat([prev_chunk, next_chunk], dim=1)
+            continue

         # Create fade curves
-        fade_out = torch.linspace(1.0, 0.0, overlap_samples).to(prev_chunk.device)
-        fade_in = torch.linspace(0.0, 1.0, overlap_samples).to(next_chunk.device)
+        fade_out = torch.linspace(1.0, 0.0, actual_overlap).to(prev_chunk.device)
+        fade_in = torch.linspace(0.0, 1.0, actual_overlap).to(next_chunk.device)

         # Get overlapping regions
-        prev_overlap = prev_chunk[:, -overlap_samples:]
-        next_overlap = next_chunk[:, :overlap_samples]
+        prev_overlap = prev_chunk[:, -actual_overlap:]
+        next_overlap = next_chunk[:, :actual_overlap]

         # Crossfade mix
         crossfaded = prev_overlap * fade_out + next_overlap * fade_in

         # Concatenate: non-overlap of prev + crossfaded + non-overlap of next
         result = torch.cat([
-            prev_chunk[:, :-overlap_samples],
+            prev_chunk[:, :-actual_overlap],
             crossfaded,
-            next_chunk[:, overlap_samples:]
+            next_chunk[:, actual_overlap:]
         ], dim=1)

     return result
@@ -168,8 +188,9 @@ def separate_audio(model_name, file_path, text_prompt, chunk_duration=DEFAULT_CH
     residual_merged = merge_chunks_with_crossfade(residual_chunks, sample_rate, OVERLAP_DURATION)

     progress(0.95, desc="Saving results...")
-    target_path = save_audio(target_merged.unsqueeze(0), sample_rate)
-    residual_path = save_audio(residual_merged.unsqueeze(0), sample_rate)
+    # merged tensors are already 2D [channels, samples]
+    target_path = save_audio(target_merged, sample_rate)
+    residual_path = save_audio(residual_merged, sample_rate)

     progress(1.0, desc="Done!")
     return target_path, residual_path, f"✅ Isolated '{text_prompt}' using {model_name} ({num_chunks} chunks)"
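
A minimal way to exercise the patched `merge_chunks_with_crossfade` on both edge cases is sketched below; it assumes app.py can be imported without heavy side effects, and the chunk shapes and durations are made up for illustration:

import torch
from app import merge_chunks_with_crossfade  # assumes importing app.py is cheap

sample_rate = 16000
overlap_duration = 0.25                       # 4000 samples of overlap

chunks = [
    torch.randn(16000),        # 1-D mono chunk, should be promoted to [1, 16000]
    torch.randn(1, 16000),     # already [channels, samples]
    torch.randn(1, 1000),      # shorter than the 4000-sample overlap window
]

merged = merge_chunks_with_crossfade(chunks, sample_rate, overlap_duration)

# First merge: 16000 + 16000 - 4000 = 28000 samples.
# Second merge: the 1000-sample chunk is fully crossfaded into the tail
# (actual_overlap = 1000), so the total length stays 28000.
assert merged.dim() == 2 and merged.shape == (1, 28000)
print(merged.shape)  # torch.Size([1, 28000])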