Tharya committed
Commit 9442c34 · verified · 1 Parent(s): 0f9d27a

Upload 3 files

Files changed (3)
  1. mae.py +483 -0
  2. requirements.txt +5 -0
  3. test.py +103 -0
mae.py ADDED
@@ -0,0 +1,483 @@
from functools import partial
import torch
import torch.nn as nn
import numpy as np
import torch.utils.checkpoint
from timm.models.swin_transformer import SwinTransformerBlock
from timm.models.vision_transformer import Block
from timm.models.layers import to_2tuple


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0
    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):
    """
    grid_size: (H, W) tuple giving the grid height and width
    return:
    pos_embed: [grid_size[0]*grid_size[1], embed_dim] or [1+grid_size[0]*grid_size[1], embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size[0], dtype=np.float32)
    grid_w = np.arange(grid_size[1], dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


class SwinTransformerBlockWrapper(torch.nn.Module):
    """
    Wrap SwinTransformerBlock to fit the input shape of [B, N, C] like TransformerBlock.

    The SwinTransformerBlock takes the input shape of [B, H, W, C], and TransformerBlock
    takes the input shape of [B, N, C].
    """

    def __init__(self, block: SwinTransformerBlock):
        super().__init__()
        self.block = block
        self.input_resolution = block.input_resolution

    def forward(self, x):
        """
        :param x: [B, N, C]
        :return: [B, N, C]
        """
        B, N, C = x.shape
        x = x.reshape(B, *self.input_resolution, C)
        x = self.block(x)
        x = x.reshape(B, N, C)
        return x


class MaskedAutoencoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """

    def __init__(
            self,
            img_size=224,
            patch_size=16,
            in_chans=3,  # input channels. 1 for audio, 3 for image
            embed_dim=1024,
            depth=24,  # transformer depth
            num_heads=16,
            decoder_mode=0,  # 0: transformer (global attn), 1: swin-transformer (shifted-window local attn)
            no_shift=False,  # only used when decoder_mode=1: disable the window shift in the Swin decoder
            decoder_embed_dim=512,
            decoder_depth=8,  # ignored when decoder_mode=1 (the Swin decoder always uses 16 blocks)
            decoder_num_heads=16,  # ignored when decoder_mode=1 (fixed to 16 heads)
            mlp_ratio=4.,  # hidden dimension / embed dimension in the transformer feed-forward layer
            norm_layer=nn.LayerNorm,
            norm_pix_loss=False,  # use (per-patch) normalized pixels as targets for computing loss
            pos_trainable=False,
    ):
        super().__init__()

        self.img_size = to_2tuple(img_size)

        self.embed_dim = embed_dim
        self.decoder_embed_dim = decoder_embed_dim
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim),
                                      requires_grad=pos_trainable)  # fixed sin-cos embedding

        self.encoder_depth = depth
        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) for _ in range(depth)])
        self.norm = norm_layer(embed_dim)

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim),
                                              requires_grad=pos_trainable)  # fixed sin-cos embedding

        self.no_shift = no_shift

        self.decoder_mode = decoder_mode

        window_size = (4, 4)
        feat_size = (self.img_size[0] // patch_size, 8)  # decoder patch grid; the 8 hard-codes 128 // 16 frequency patches

        if self.decoder_mode == 1:
            decoder_modules = []
            for index in range(16):
                if self.no_shift:
                    shift_size = (0, 0)
                else:
                    if (index % 2) == 0:
                        shift_size = (0, 0)
                    else:
                        shift_size = (2, 0)
                # shift_size = tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size])
                decoder_modules.append(
                    SwinTransformerBlockWrapper(
                        SwinTransformerBlock(
                            dim=decoder_embed_dim,
                            input_resolution=feat_size,
                            num_heads=16,
                            window_size=window_size,
                            shift_size=shift_size,
                            mlp_ratio=mlp_ratio,
                            proj_drop=0.0,
                            attn_drop=0.0,
                            drop_path=0.0,
                            norm_layer=norm_layer,
                        )
                    )
                )
            self.decoder_blocks = nn.ModuleList(decoder_modules)
        else:
            # Transformer
            self.decoder_blocks = nn.ModuleList([
                Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
                for _ in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True)  # decoder to patch

        self.norm_pix_loss = norm_pix_loss

        self.patch_size = patch_size

        self.initialize_weights()

    def initialize_weights(self):
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed_flexible(self.pos_embed.shape[-1], self.patch_embed.patch_hw,
                                                     cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed_flexible(self.decoder_pos_embed.shape[-1],
                                                             self.patch_embed.patch_hw, cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        imgs: (N, 1, H, W)
        x: (N, L, patch_size**2 * 1)
        L = (H/p)*(W/p)
        """
        p = self.patch_embed.patch_size[0]

        h = imgs.shape[2] // p
        w = imgs.shape[3] // p
        # h,w = self.patch_embed.patch_hw
        x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(imgs.shape[0], h * w, p ** 2 * 1)

        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 * 1)
        specs: (N, 1, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = self.img_size[0] // p
        w = 128 // p  # the spectrogram width (number of mel bins) is hard-coded to 128
        x = x.reshape(shape=(x.shape[0], h, w, p, p, 1))
        x = torch.einsum('nhwpqc->nchpwq', x)
        specs = x.reshape(x.shape[0], 1, h * p, w * p)
        return specs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        """
        :param x: [N, C, H, W]
        :param mask_ratio: float. ratio of masked patches
        :return: tuple. x: [N, L', D], mask: [N, L], ids_restore: [N, L]
        """
        # embed patches
        x = self.patch_embed(x)

        B, L, D = x.shape

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:L + 1, :]

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_encoder_no_mask(
            self,
            x,
            header='mean'
    ):
        """
        :param x: [N, C, H, W]
        :param header: str. 'cls' or 'mean'
        :return: emb: [N, D], the cls-token embedding or the mean over patch embeddings
        """
        # embed patches
        x = self.patch_embed(x)

        B, L, D = x.shape

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:L + 1, :]

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for n, blk in enumerate(self.blocks):
            x = blk(x)

        x = self.norm(x)

        if header == 'cls':
            emb = x[:, 0, :]
        elif header == 'mean':
            emb = x[:, 1:, :].mean(dim=1)
        else:
            raise NotImplementedError

        return emb

    def forward_decoder(self, x, ids_restore):
        """
        :param x: [N, L', D] (visible tokens plus cls token)
        :param ids_restore: [N, L]
        :return: pred: [N, L, p*p*in_chans]
        """
        # embed tokens
        x = self.decoder_embed(x)  # [N, L', D] -> [N, L', D']

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        B, L, D = x.shape

        # add pos embed
        x = x + self.decoder_pos_embed[:, :L, :]

        if self.decoder_mode != 0:
            B, L, D = x.shape
            x = x[:, 1:, :]

        if self.decoder_mode > 3:  # mvit
            x = self.decoder_blocks(x)
        else:
            # apply Transformer blocks
            for blk in self.decoder_blocks:
                x = blk(x)

        x = self.decoder_norm(x)

        # predictor projection
        pred = self.decoder_pred(x)

        # remove cls token
        if self.decoder_mode == 0:
            pred = pred[:, 1:, :]

        return pred

    def forward_loss(self, imgs, pred, mask, norm_pix_loss=False):
        """
        imgs: [N, C, H, W]
        pred: [N, L, p*p*C]
        mask: [N, L], 0 is keep, 1 is remove
        """
        target = self.patchify(imgs)
        if norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6) ** .5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward(self, imgs, mask_ratio=0.8):
        """
        :param imgs: [N, C, H, W]
        :param mask_ratio: float. ratio of masked patches
        :return: tuple. loss_recon: float, pred: [N, L, p*p*C], mask: [N, L]
        """
        emb_enc, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(emb_enc, ids_restore)  # [N, L, p*p*C]
        loss_recon = self.forward_loss(imgs, pred, mask, norm_pix_loss=self.norm_pix_loss)
        return loss_recon, pred, mask


if __name__ == '__main__':
    device = 'cpu'
    # device = 'cuda'

    # Model
    audio_mae = MaskedAutoencoderViT(
        img_size=(2048, 128),
        patch_size=16,
        in_chans=1,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_mode=1,
        no_shift=False,
        decoder_embed_dim=512,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        norm_pix_loss=False,
        pos_trainable=False,
    )

    # Load pre-trained weights
    ckpt_path = 'music-mae-32kHz.pth'
    audio_mae.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    audio_mae.to(device)

    # Generate a batch of random inputs: (N, C, H, W), N=4 (batch size), C=1 (channel), H=2048, W=128
    # Each input is a mel spectrogram with shape (2048, 128)
    x = torch.randn(4, 1, 2048, 128).to(device)

    # Compute the representation of the input batch
    emb = audio_mae.forward_encoder_no_mask(x, header='mean')  # torch.Size([4, 768])
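
Aside (not part of the uploaded files): the per-sample masking in random_masking and the unshuffle in forward_decoder both hinge on the pair of argsorts that produce ids_shuffle and ids_restore. The self-contained sketch below replays that index round-trip on a toy tensor, showing that every kept patch ends up back at its original position once placeholder mask tokens are appended and gathered with ids_restore.

import torch

# toy batch, sequence length, and feature dim (hypothetical values for illustration)
N, L, D = 2, 8, 4
x = torch.arange(N * L * D, dtype=torch.float32).reshape(N, L, D)

noise = torch.rand(N, L)
ids_shuffle = torch.argsort(noise, dim=1)     # ascending: low noise = keep
ids_restore = torch.argsort(ids_shuffle, dim=1)

len_keep = int(L * (1 - 0.75))                # e.g. mask_ratio = 0.75 -> keep 2 patches
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).repeat(1, 1, D))

# append placeholder mask tokens, then unshuffle back to the original order
mask_tokens = torch.zeros(N, L - len_keep, D)
x_restored = torch.gather(torch.cat([x_masked, mask_tokens], dim=1), 1,
                          ids_restore.unsqueeze(-1).repeat(1, 1, D))

# binary mask: 0 = kept, 1 = removed (same convention as mae.py)
mask = torch.ones(N, L)
mask[:, :len_keep] = 0
mask = torch.gather(mask, 1, ids_restore)

# kept patches are back at their original positions
assert torch.allclose(x_restored[mask == 0], x[mask == 0])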
requirements.txt ADDED
@@ -0,0 +1,5 @@
torch==2.1.1
timm==0.9.12
numpy==1.24.4
librosa==0.10.1
miniaudio==1.59
test.py ADDED
@@ -0,0 +1,103 @@
from typing import Union, Tuple
import numpy as np
from numpy.typing import NDArray
import torch
from torch import nn
from functools import partial
import matplotlib.pyplot as plt
from PIL import Image
import librosa
import miniaudio

from mae import MaskedAutoencoderViT


def load_audio(
        path: str,
        sr: int = 32000,
        duration: int = 20,
) -> np.ndarray:
    g = miniaudio.stream_file(path, output_format=miniaudio.SampleFormat.FLOAT32, nchannels=1,
                              sample_rate=sr, frames_to_read=sr * duration)
    signal = np.array(next(g))
    return signal


def mel_spectrogram(
        signal: np.ndarray,
        sr: int = 32000,
        n_fft: int = 800,
        hop_length: int = 320,
        n_mels: int = 128,
) -> np.ndarray:
    mel_spec = librosa.feature.melspectrogram(
        y=signal, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
        window='hann', pad_mode='constant'
    )
    mel_spec = librosa.power_to_db(mel_spec)  # (freq, time)
    return mel_spec.T  # (time, freq)


def display_image(
        img: Union[NDArray, Image.Image],
        figsize: Tuple[float, float] = (5, 5),
) -> None:
    plt.figure(figsize=figsize)
    plt.imshow(img, origin='lower', aspect='auto')  # e.g. cmap='viridis' or 'coolwarm'
    plt.axis('off')
    plt.colorbar()
    plt.tight_layout()
    plt.show()


def normalize(arr: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    return (arr - arr.mean()) / (arr.std() + eps)


if __name__ == '__main__':
    mp3_file = "/Users/chenjing22/Downloads/songs/See You Again.mp3"
    mel_spec = mel_spectrogram(load_audio(mp3_file, duration=21))  # (time, freq)

    # padding or truncating
    length = mel_spec.shape[0]
    target_length = 2048
    mel_spec = mel_spec[:target_length] if length > target_length else np.pad(
        mel_spec, ((0, target_length - length), (0, 0)), mode='constant', constant_values=mel_spec.min()
    )

    # normalize
    mel_spec = normalize(mel_spec)  # (2048, 128)

    display_image(mel_spec.T, figsize=(10, 4))

    # Model
    mae = MaskedAutoencoderViT(
        img_size=(2048, 128),
        patch_size=16,
        in_chans=1,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_mode=1,
        no_shift=False,
        decoder_embed_dim=512,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        norm_pix_loss=False,
        pos_trainable=False,
    )

    # Load pre-trained weights
    ckpt_path = 'music-mae-32kHz.pth'
    mae.load_state_dict(torch.load(ckpt_path, map_location='cpu'))

    device = 'cpu'  # 'cuda'
    mae.to(device)

    x = torch.from_numpy(mel_spec).unsqueeze(0).unsqueeze(0).to(device)  # (1, 1, 2048, 128)
    mse_loss, y, mask = mae(x, mask_ratio=0.7)  # y: (1, 1024, 256), mask: (1, 1024)

    y[mask == 0.] = mae.patchify(x)[mask == 0.]
    x_reconstructed = mae.unpatchify(y).squeeze(0).squeeze(0).detach().numpy()

    print(f'mse_loss: {mse_loss.item()}')
    display_image(x_reconstructed.T, figsize=(10, 4))
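
Aside (a sketch, not part of the uploaded files): the same patchify/unpatchify helpers can also visualize what the encoder actually saw. Assuming the script above has just run, so mae, x, mask, and display_image are still in scope, the snippet below blanks out the roughly 70% of patches that were masked and displays the result for comparison with the reconstruction.

# visualize the masked input using only helpers defined in mae.py / test.py
masked = mae.patchify(x).clone()              # (1, 1024, 256)
masked[mask == 1.] = x.min()                  # blank out the patches the encoder never saw
masked_spec = mae.unpatchify(masked).squeeze(0).squeeze(0).detach().numpy()
display_image(masked_spec.T, figsize=(10, 4))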