| import torch | |
| import torch.nn as nn | |
| from torch_audiomentations import Compose, Gain, PolarityInversion, AddColoredNoise, PitchShift, PeakNormalization, PitchShift | |
| # TODO add where I copied the code from | |
class AUG(nn.Module):
    """Random waveform augmentation pipeline built on torch_audiomentations.

    Wraps a ``Compose`` of four transforms, listed in this order, each gated
    by its own probability ``p``:

    1. ``AddColoredNoise`` with probability ``prob``.
    2. ``PitchShift`` by -1..+1 semitones with probability ``prob``.
    3. ``PeakNormalization`` with a fixed probability of 0.1.
    4. ``Gain`` of -6..+6 dB with probability ``prob``.

    Args:
        prob: Probability passed to the noise, pitch-shift and gain transforms.
        sample_rate: Sample rate (Hz) of the input audio. It is given to
            ``PitchShift`` at construction time and to the pipeline on every
            forward call, so the two values cannot drift apart. Defaults to
            16000, the value that was previously hard-coded in both places.
    """

    def __init__(self, prob: float = 0.3, sample_rate: int = 16000):
        super().__init__()
        # Single source of truth for the rate used at construction and call time.
        self.sample_rate = sample_rate
        self.aug = Compose(
            transforms=[
                AddColoredNoise(p=prob),
                # PitchShift needs the sample rate up front (unlike the others,
                # which receive it when the Compose is called in forward()).
                PitchShift(
                    sample_rate=sample_rate,
                    min_transpose_semitones=-1,
                    max_transpose_semitones=1,
                    p=prob,
                ),
                # NOTE(review): fixed at 0.1 rather than `prob`; kept as in the original.
                PeakNormalization(p=0.1),
                Gain(min_gain_in_db=-6, max_gain_in_db=6, p=prob),
            ]
        )

    def forward(self, x):
        """Apply the augmentation pipeline to ``x`` and return the result.

        NOTE(review): assumes ``x`` is a waveform tensor shaped
        (batch, channels, samples), the layout torch_audiomentations expects.
        Depending on the installed version, ``Compose`` may return a tensor or
        an ``ObjectDict`` (see its ``output_type`` option). TODO confirm with
        callers.
        """
        return self.aug(x, sample_rate=self.sample_rate)