admin committed on
Commit babad42 · 1 Parent(s): db92ea3
Files changed (1)
  1. README.md +72 -0
README.md CHANGED
@@ -7,8 +7,80 @@ The demucs model in the ICASSP 2024 Cadenza Challenge is an innovative sound sep
 
  ## Usage
  ```python
+ import torch
+ import torchaudio
+ from typing import Callable
+ from functools import partial
+ from dataclasses import dataclass
  from modelscope import snapshot_download
+ from torchaudio.models import hdemucs_high
+
+ @dataclass
+ class SourceSeparationBundle:
+     """Dataclass that bundles components for performing source separation.
+
+     Example
+         >>> import torchaudio
+         >>> from torchaudio.pipelines import CONVTASNET_BASE_LIBRI2MIX
+         >>> import torch
+         >>>
+         >>> # Build the separation model.
+         >>> model = CONVTASNET_BASE_LIBRI2MIX.get_model()
+         >>> 100%|███████████████████████████████|19.1M/19.1M [00:04<00:00, 4.93MB/s]
+         >>>
+         >>> # Instantiate the test set of Libri2Mix dataset.
+         >>> dataset = torchaudio.datasets.LibriMix("/home/datasets/", subset="test")
+         >>>
+         >>> # Apply source separation on mixture audio.
+         >>> for i, data in enumerate(dataset):
+         >>>     sample_rate, mixture, clean_sources = data
+         >>>     # Make sure the shape of input suits the model requirement.
+         >>>     mixture = mixture.reshape(1, 1, -1)
+         >>>     estimated_sources = model(mixture)
+         >>>     score = si_snr_pit(estimated_sources, clean_sources)  # for demonstration
+         >>>     print(f"Si-SNR score is : {score}.")
+         >>>     break
+         >>> Si-SNR score is : 16.24.
+         >>>
+     """
+
+     _model_path: str
+     _model_factory_func: Callable[[], torch.nn.Module]
+     _sample_rate: int
+
+     @property
+     def sample_rate(self) -> int:
+         """Sample rate of the audio that the model is trained on.
+
+         :type: int
+         """
+         return self._sample_rate
+
+     def get_model(self) -> torch.nn.Module:
+         """Construct the model and load the pretrained weight."""
+         model = self._model_factory_func()
+         path = torchaudio.utils.download_asset(self._model_path)
+         state_dict = torch.load(path)
+         model.load_state_dict(state_dict)
+         model.eval()
+         return model
+
  model_dir = snapshot_download('monetjoe/hdemucs_high_musdbhq')
+ HDEMUCS_HIGH_MUSDB = SourceSeparationBundle(
+     _model_path=f"{model_dir}/hdemucs_high_musdbhq_only.pt",
+     _model_factory_func=partial(
+         hdemucs_high, sources=["drums", "bass", "other", "vocals"]
+     ),
+     _sample_rate=44100,
+ )
+ HDEMUCS_HIGH_MUSDB.__doc__ = """Pre-trained music source separation pipeline with
+ *Hybrid Demucs* :cite:`defossez2021hybrid` trained on the training set of MUSDB-HQ :cite:`MUSDB18HQ`.
+
+ The model is constructed by :func:`~torchaudio.models.hdemucs_high`.
+ Training was performed in the original HDemucs repository `here <https://github.com/facebookresearch/demucs/>`__.
+
+ Please refer to :class:`SourceSeparationBundle` for usage instructions.
+ """
  ```
 
  ## Maintenance
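
For context, here is a minimal sketch of how the `HDEMUCS_HIGH_MUSDB` bundle defined in the README above could be used to separate a mixture. It is not part of the commit: the input file name, the resampling step, and the stem-to-name mapping are illustrative assumptions; only `get_model()`, `sample_rate`, and the four-source output follow from the bundle definition.

```python
import torch
import torchaudio

# Load the pretrained Hybrid Demucs model via the bundle defined in the README.
model = HDEMUCS_HIGH_MUSDB.get_model()

# "mixture.wav" is a placeholder path; the model expects stereo audio at 44.1 kHz.
waveform, sample_rate = torchaudio.load("mixture.wav")
if sample_rate != HDEMUCS_HIGH_MUSDB.sample_rate:
    waveform = torchaudio.functional.resample(
        waveform, sample_rate, HDEMUCS_HIGH_MUSDB.sample_rate
    )

# Forward pass over a (batch, channels, time) tensor; the output has shape
# (batch, num_sources, channels, time), with sources ordered as in the bundle.
with torch.no_grad():
    sources = model(waveform.unsqueeze(0)).squeeze(0)

# Map each separated stem to its name and write it out (names assumed from the bundle).
for name, stem in zip(["drums", "bass", "other", "vocals"], sources):
    torchaudio.save(f"{name}.wav", stem, HDEMUCS_HIGH_MUSDB.sample_rate)
```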