import math
from typing import Callable, Optional, Iterable

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
from jaxtyping import Array

def trunc_normal(mean=0., std=1., a=-2., b=2., dtype=jnp.float32) -> Callable:
    """Truncated normal initialization function"""
    def init(key, shape, dtype=dtype) -> Array:
        # https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/weight_init.py
        def norm_cdf(x):
            # Computes standard normal cumulative distribution function
            return (1. + math.erf(x / math.sqrt(2.))) / 2.

        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)
        out = jax.random.uniform(key, shape, dtype=dtype, minval=2 * l - 1, maxval=2 * u - 1)
        out = jax.scipy.special.erfinv(out) * std * math.sqrt(2.) + mean
        return jnp.clip(out, a, b)
    return init
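
# Usage sketch (not part of the original file): trunc_normal() returns an
# initializer with the standard Flax signature (key, shape, dtype), so it can be
# passed wherever nn.initializers.* callables are accepted. The key and shape
# below are illustrative assumptions.
def _trunc_normal_demo():
    key = jax.random.PRNGKey(0)
    w = trunc_normal(std=.02)(key, (3, 3, 180, 180))
    assert float(jnp.max(jnp.abs(w))) <= 2.0  # samples are clipped to [a, b]
    return w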

def Dense(features, use_bias=True, kernel_init=trunc_normal(std=.02), bias_init=nn.initializers.zeros):
    return nn.Dense(features, use_bias=use_bias, kernel_init=kernel_init, bias_init=bias_init)


def LayerNorm():
    """torch LayerNorm uses larger epsilon by default"""
    return nn.LayerNorm(epsilon=1e-05)

class Mlp(nn.Module):
    in_features: int
    hidden_features: Optional[int] = None
    out_features: Optional[int] = None
    act_layer: Callable = nn.gelu
    drop: float = 0.0

    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(self.hidden_features or self.in_features)(x)
        x = self.act_layer(x)
        x = nn.Dropout(self.drop, deterministic=not training)(x)
        x = nn.Dense(self.out_features or self.in_features)(x)
        x = nn.Dropout(self.drop, deterministic=not training)(x)
        return x
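
# Usage sketch (assumption, not in the original file): the inline nn.Dropout
# layers mean that training calls need an rng stream named "dropout"; with
# training=False no extra rng is required. Sizes are illustrative.
def _mlp_demo():
    mlp = Mlp(in_features=96, hidden_features=192, drop=0.1)
    x = jnp.ones((2, 64, 96))
    variables = mlp.init(jax.random.PRNGKey(0), x, training=False)
    return mlp.apply(variables, x, training=True,
                     rngs={"dropout": jax.random.PRNGKey(1)})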

def window_partition(x, window_size: int):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.reshape((B, H // window_size, window_size, W // window_size, window_size, C))
    windows = x.transpose((0, 1, 3, 2, 4, 5)).reshape((-1, window_size, window_size, C))
    return windows

def window_reverse(windows, window_size: int, H: int, W: int):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.reshape((B, H // window_size, W // window_size, window_size, window_size, -1))
    x = x.transpose((0, 1, 3, 2, 4, 5)).reshape((B, H, W, -1))
    return x
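
# Shape sketch (not part of the original file): window_partition and
# window_reverse are inverses when H and W are divisible by window_size.
# The 16x16 input and window size 8 below are illustrative assumptions.
def _window_roundtrip_demo():
    x = jnp.arange(2 * 16 * 16 * 3, dtype=jnp.float32).reshape(2, 16, 16, 3)
    windows = window_partition(x, 8)        # (2 * 4, 8, 8, 3)
    y = window_reverse(windows, 8, 16, 16)  # (2, 16, 16, 3)
    assert jnp.allclose(x, y)
    return windows.shape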

class DropPath(nn.Module):
    """
    Implementation referred from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
    """
    dropout_prob: float = 0.1
    deterministic: Optional[bool] = None

    @nn.compact
    def __call__(self, input, training):
        if not training:
            return input
        keep_prob = 1 - self.dropout_prob
        # Stochastic depth: one Bernoulli(keep_prob) draw per batch element,
        # broadcast over the remaining dimensions.
        shape = (input.shape[0],) + (1,) * (input.ndim - 1)
        rng = self.make_rng("dropout")
        random_tensor = keep_prob + jax.random.uniform(rng, shape)
        random_tensor = jnp.floor(random_tensor)
        return jnp.divide(input, keep_prob) * random_tensor
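
# Sketch (not part of the original file): DropPath draws from the "dropout"
# rng stream, so training calls need rngs={"dropout": ...}; at evaluation time
# it is the identity. Shapes below are illustrative.
def _drop_path_demo():
    dp = DropPath(dropout_prob=0.2)
    x = jnp.ones((4, 64, 96))
    variables = dp.init(jax.random.PRNGKey(0), x, training=False)  # no params are created
    return dp.apply(variables, x, training=True,
                    rngs={"dropout": jax.random.PRNGKey(1)})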

class WindowAttention(nn.Module):
    dim: int
    window_size: Iterable[int]
    num_heads: int
    qkv_bias: bool = True
    qk_scale: Optional[float] = None
    att_drop: float = 0.0
    proj_drop: float = 0.0

    def make_rel_pos_index(self):
        h_indices = np.arange(0, self.window_size[0])
        w_indices = np.arange(0, self.window_size[1])
        indices = np.stack(np.meshgrid(w_indices, h_indices, indexing="ij"))
        flatten_indices = np.reshape(indices, (2, -1))
        relative_indices = flatten_indices[:, :, None] - flatten_indices[:, None, :]
        relative_indices = np.transpose(relative_indices, (1, 2, 0))
        relative_indices[:, :, 0] += self.window_size[0] - 1
        relative_indices[:, :, 1] += self.window_size[1] - 1
        relative_indices[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_pos_index = np.sum(relative_indices, -1)
        return relative_pos_index

    @nn.compact
    def __call__(self, inputs, mask, training):
        rpbt = self.param(
            "relative_position_bias_table",
            trunc_normal(std=.02),
            (
                (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1),
                self.num_heads,
            ),
        )
        #relative_pos_index = self.variable(
        #    "variables", "relative_position_index", self.get_rel_pos_index
        #)
        batch, n, channels = inputs.shape
        qkv = nn.Dense(self.dim * 3, use_bias=self.qkv_bias, name="qkv")(inputs)
        qkv = qkv.reshape(batch, n, 3, self.num_heads, channels // self.num_heads)
        qkv = jnp.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]

        scale = self.qk_scale or (self.dim // self.num_heads) ** -0.5
        q = q * scale
        att = q @ jnp.swapaxes(k, -2, -1)

        rel_pos_bias = jnp.reshape(
            rpbt[np.reshape(self.make_rel_pos_index(), (-1))],
            (
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1],
                -1,
            ),
        )
        rel_pos_bias = jnp.transpose(rel_pos_bias, (2, 0, 1))
        att += jnp.expand_dims(rel_pos_bias, 0)

        if mask is not None:
            att = jnp.reshape(
                att, (batch // mask.shape[0], mask.shape[0], self.num_heads, n, n)
            )
            att = att + jnp.expand_dims(jnp.expand_dims(mask, 1), 0)
            att = jnp.reshape(att, (-1, self.num_heads, n, n))
            att = jax.nn.softmax(att)
        else:
            att = jax.nn.softmax(att)
        att = nn.Dropout(self.att_drop)(att, deterministic=not training)

        x = jnp.reshape(jnp.swapaxes(att @ v, 1, 2), (batch, n, channels))
        x = nn.Dense(self.dim, name="proj")(x)
        x = nn.Dropout(self.proj_drop)(x, deterministic=not training)
        return x
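
# Usage sketch (not part of the original file): one attention call over a batch
# of flattened 8x8 windows with no shift mask. The dim, head count, and batch
# size are illustrative assumptions.
def _window_attention_demo():
    attn = WindowAttention(dim=96, window_size=(8, 8), num_heads=6)
    x = jnp.ones((4, 8 * 8, 96))  # (num_windows * B, window_size**2, C)
    variables = attn.init(jax.random.PRNGKey(0), x, mask=None, training=False)
    return attn.apply(variables, x, mask=None, training=False)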

class SwinTransformerBlock(nn.Module):
    dim: int
    input_resolution: tuple[int, int]
    num_heads: int
    window_size: int = 7
    shift_size: int = 0
    mlp_ratio: float = 4.
    qkv_bias: bool = True
    qk_scale: Optional[float] = None
    drop: float = 0.
    attn_drop: float = 0.
    drop_path: float = 0.
    act_layer: Callable = nn.activation.gelu
    norm_layer: Callable = LayerNorm

    @staticmethod
    def make_att_mask(shift_size, window_size, height, width):
        if shift_size > 0:
            mask = jnp.zeros([1, height, width, 1])
            h_slices = (
                slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None),
            )
            w_slices = (
                slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None),
            )
            count = 0
            for h in h_slices:
                for w in w_slices:
                    mask = mask.at[:, h, w, :].set(count)
                    count += 1

            mask_windows = window_partition(mask, window_size)
            mask_windows = jnp.reshape(mask_windows, (-1, window_size * window_size))
            att_mask = jnp.expand_dims(mask_windows, 1) - jnp.expand_dims(mask_windows, 2)
            att_mask = jnp.where(att_mask != 0.0, float(-100.0), att_mask)
            att_mask = jnp.where(att_mask == 0.0, float(0.0), att_mask)
        else:
            att_mask = None
        return att_mask

    @nn.compact
    def __call__(self, x, x_size, training):
        H, W = x_size
        B, L, C = x.shape

        # Flax modules are frozen dataclasses, so work with local copies instead
        # of mutating self.shift_size / self.window_size.
        window_size, shift_size = self.window_size, self.shift_size
        if min(self.input_resolution) <= window_size:
            # if window size is larger than input resolution, we don't partition windows
            shift_size = 0
            window_size = min(self.input_resolution)
        assert 0 <= shift_size < window_size, "shift_size must be in [0, window_size)"

        shortcut = x
        x = self.norm_layer()(x)
        x = x.reshape((B, H, W, C))

        # cyclic shift
        if shift_size > 0:
            shifted_x = jnp.roll(x, (-shift_size, -shift_size), axis=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.reshape((-1, window_size * window_size, C))  # nW*B, window_size*window_size, C

        #attn_mask = self.variable(
        #    "variables",
        #    "attn_mask",
        #    self.get_att_mask,
        #    self.shift_size,
        #    self.window_size,
        #    self.input_resolution[0],
        #    self.input_resolution[1]
        #)
        attn_mask = self.make_att_mask(shift_size, window_size, *self.input_resolution)

        attn = WindowAttention(self.dim, (window_size, window_size), self.num_heads,
                               self.qkv_bias, self.qk_scale, self.attn_drop, self.drop)
        if self.input_resolution == x_size:
            attn_windows = attn(x_windows, attn_mask, training)  # nW*B, window_size*window_size, C
        else:
            # test time
            assert not training
            test_mask = self.make_att_mask(shift_size, window_size, *x_size)
            attn_windows = attn(x_windows, test_mask, training=False)

        # merge windows
        attn_windows = attn_windows.reshape((-1, window_size, window_size, C))
        shifted_x = window_reverse(attn_windows, window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if shift_size > 0:
            x = jnp.roll(shifted_x, (shift_size, shift_size), axis=(1, 2))
        else:
            x = shifted_x
        x = x.reshape((B, H * W, C))

        # FFN
        x = shortcut + DropPath(self.drop_path)(x, training)
        norm = self.norm_layer()(x)
        mlp = Mlp(in_features=self.dim, hidden_features=int(self.dim * self.mlp_ratio),
                  act_layer=self.act_layer, drop=self.drop)(norm, training)
        x = x + DropPath(self.drop_path)(mlp, training)
        return x
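
# Usage sketch (not part of the original file): a single shifted-window block
# over a 16x16 token grid. All sizes are illustrative assumptions; with
# training=False no dropout rng is required.
def _swin_block_demo():
    block = SwinTransformerBlock(dim=96, input_resolution=(16, 16), num_heads=6,
                                 window_size=8, shift_size=4)
    x = jnp.ones((1, 16 * 16, 96))
    variables = block.init(jax.random.PRNGKey(0), x, (16, 16), training=False)
    return block.apply(variables, x, (16, 16), training=False)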

class PatchMerging(nn.Module):
    inp_res: Iterable[int]
    dim: int
    norm_layer: Callable = LayerNorm

    @nn.compact
    def __call__(self, inputs):
        batch, n, channels = inputs.shape
        height, width = self.inp_res[0], self.inp_res[1]
        x = jnp.reshape(inputs, (batch, height, width, channels))

        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = jnp.concatenate([x0, x1, x2, x3], axis=-1)
        x = jnp.reshape(x, (batch, -1, 4 * channels))

        x = self.norm_layer()(x)
        x = nn.Dense(2 * self.dim, use_bias=False)(x)
        return x
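
# Shape sketch (not part of the original file): PatchMerging concatenates each
# 2x2 neighbourhood, so the token count drops by 4x while channels go from dim
# to 2*dim. Sizes below are illustrative.
def _patch_merging_demo():
    merge = PatchMerging(inp_res=(16, 16), dim=96)
    x = jnp.ones((1, 16 * 16, 96))
    variables = merge.init(jax.random.PRNGKey(0), x)
    y = merge.apply(variables, x)
    assert y.shape == (1, 8 * 8, 192)
    return y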

class BasicLayer(nn.Module):
    dim: int
    input_resolution: Iterable[int]
    depth: int
    num_heads: int
    window_size: int
    mlp_ratio: float = 4.
    qkv_bias: bool = True
    qk_scale: Optional[float] = None
    drop: float = 0.
    attn_drop: float = 0.
    drop_path: float = 0.
    norm_layer: Callable = LayerNorm
    downsample: Optional[Callable] = None

    @nn.compact
    def __call__(self, x, x_size, training):
        for i in range(self.depth):
            x = SwinTransformerBlock(
                self.dim,
                self.input_resolution,
                self.num_heads,
                self.window_size,
                0 if (i % 2 == 0) else self.window_size // 2,
                self.mlp_ratio,
                self.qkv_bias,
                self.qk_scale,
                self.drop,
                self.attn_drop,
                self.drop_path[i] if isinstance(self.drop_path, (list, tuple)) else self.drop_path,
                norm_layer=self.norm_layer
            )(x, x_size, training)
        if self.downsample is not None:
            x = self.downsample(self.input_resolution, dim=self.dim, norm_layer=self.norm_layer)(x)
        return x

class RSTB(nn.Module):
    dim: int
    input_resolution: Iterable[int]
    depth: int
    num_heads: int
    window_size: int
    mlp_ratio: float = 4.
    qkv_bias: bool = True
    qk_scale: Optional[float] = None
    drop: float = 0.
    attn_drop: float = 0.
    drop_path: float = 0.
    norm_layer: Callable = LayerNorm
    downsample: Optional[Callable] = None
    img_size: int = 224
    patch_size: int = 4
    resi_connection: str = '1conv'

    @nn.compact
    def __call__(self, x, x_size, training):
        res = x
        x = BasicLayer(dim=self.dim,
                       input_resolution=self.input_resolution,
                       depth=self.depth,
                       num_heads=self.num_heads,
                       window_size=self.window_size,
                       mlp_ratio=self.mlp_ratio,
                       qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
                       drop=self.drop, attn_drop=self.attn_drop,
                       drop_path=self.drop_path,
                       norm_layer=self.norm_layer,
                       downsample=self.downsample)(x, x_size, training)
        x = PatchUnEmbed(embed_dim=self.dim)(x, x_size)
        # resi_connection == '1conv':
        x = nn.Conv(self.dim, (3, 3))(x)
        x = PatchEmbed()(x)
        return x + res

class PatchEmbed(nn.Module):
    norm_layer: Optional[Callable] = None

    @nn.compact
    def __call__(self, x):
        x = x.reshape((x.shape[0], -1, x.shape[-1]))  # B Ph Pw C -> B Ph*Pw C
        if self.norm_layer is not None:
            x = self.norm_layer()(x)
        return x

class PatchUnEmbed(nn.Module):
    embed_dim: int = 96

    def __call__(self, x, x_size):
        B, HW, C = x.shape
        x = x.reshape((B, x_size[0], x_size[1], self.embed_dim))
        return x
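
# Shape sketch (not part of the original file): PatchEmbed flattens the spatial
# grid into a token axis and PatchUnEmbed restores it for a known x_size. The
# 24x24 feature map below is an illustrative assumption.
def _patch_embed_roundtrip_demo():
    x = jnp.ones((1, 24, 24, 96))
    embed = PatchEmbed()
    tokens = embed.apply(embed.init(jax.random.PRNGKey(0), x), x)  # (1, 576, 96)
    y = PatchUnEmbed(embed_dim=96).apply({}, tokens, (24, 24))     # (1, 24, 24, 96)
    assert y.shape == x.shape
    return y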

class SwinIR(nn.Module):
    r""" SwinIR JAX implementation

    Args:
        img_size (int | tuple(int)): Input image size. Default 64
        patch_size (int | tuple(int)): Patch size. Default: 1
        in_chans (int): Number of input image channels. Default: 3
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
        img_range: Image range. 1. or 255.
    """
    img_size: int = 48
    patch_size: int = 1
    in_chans: int = 3
    embed_dim: int = 180
    depths: tuple = (6, 6, 6, 6, 6, 6)
    num_heads: tuple = (6, 6, 6, 6, 6, 6)
    window_size: int = 8
    mlp_ratio: float = 2.
    qkv_bias: bool = True
    qk_scale: Optional[float] = None
    drop_rate: float = 0.
    attn_drop_rate: float = 0.
    drop_path_rate: float = 0.1
    norm_layer: Callable = LayerNorm
    ape: bool = False
    patch_norm: bool = True
    upscale: int = 2
    img_range: float = 1.
    num_feat: int = 64

    def pad(self, x):
        _, h, w, _ = x.shape
        mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
        mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
        x = jnp.pad(x, ((0, 0), (0, mod_pad_h), (0, mod_pad_w), (0, 0)), 'reflect')
        return x

    @nn.compact
    def __call__(self, x, training):
        _, h_before, w_before, _ = x.shape
        x = self.pad(x)
        _, h, w, _ = x.shape

        patches_resolution = [self.img_size // self.patch_size] * 2
        num_patches = patches_resolution[0] * patches_resolution[1]

        # conv_first
        x = nn.Conv(self.embed_dim, (3, 3))(x)
        res = x

        # feature extraction
        x_size = (h, w)
        x = PatchEmbed(self.norm_layer if self.patch_norm else None)(x)
        if self.ape:
            absolute_pos_embed = \
                self.param('ape', trunc_normal(std=.02), (1, num_patches, self.embed_dim))
            x = x + absolute_pos_embed
        x = nn.Dropout(self.drop_rate, deterministic=not training)(x)

        # stochastic depth decay rule
        dpr = [d.item() for d in np.linspace(0, self.drop_path_rate, sum(self.depths))]
        for i_layer in range(len(self.depths)):
            x = RSTB(
                dim=self.embed_dim,
                input_resolution=(patches_resolution[0], patches_resolution[1]),
                depth=self.depths[i_layer],
                num_heads=self.num_heads[i_layer],
                window_size=self.window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
                drop=self.drop_rate, attn_drop=self.attn_drop_rate,
                drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],
                norm_layer=self.norm_layer,
                downsample=None,
                img_size=self.img_size,
                patch_size=self.patch_size)(x, x_size, training)
        x = self.norm_layer()(x)  # B L C
        x = PatchUnEmbed(self.embed_dim)(x, x_size)

        # conv_after_body
        x = nn.Conv(self.embed_dim, (3, 3))(x)
        x = x + res

        # conv_before_upsample
        x = nn.activation.leaky_relu(nn.Conv(self.num_feat, (3, 3))(x))

        # revert padding
        x = x[:, :-(h - h_before) or None, :-(w - w_before) or None]
        return x
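
# End-to-end usage sketch (not part of the original file): initialize a scaled
# down SwinIR and run one forward pass. The small depths/embed_dim and the
# 48x48 RGB input are illustrative assumptions to keep the example light; the
# class defaults above correspond to a much larger model.
def _swinir_demo():
    model = SwinIR(img_size=48, embed_dim=60, depths=(2, 2), num_heads=(6, 6),
                   window_size=8, mlp_ratio=2.)
    x = jnp.ones((1, 48, 48, 3))
    variables = model.init(jax.random.PRNGKey(0), x, training=False)
    feats = model.apply(variables, x, training=False)  # (1, 48, 48, num_feat)
    return feats.shape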