Spaces: Running on L4
| # from https://github.com/isaaccorley/jax-enhance | |
| from functools import partial | |
| from typing import Any, Sequence, Callable | |
| import jax.numpy as jnp | |
| import flax.linen as nn | |
| from flax.core.frozen_dict import freeze | |
| import einops | |
class PixelShuffle(nn.Module):
    """Depth-to-space rearrangement for NHWC tensors.

    Folds ``(c * s * s)`` channels into an ``s``-times larger spatial grid,
    i.e. ``(b, h, w, c*s*s) -> (b, h*s, w*s, c)``.
    """
    scale_factor: int

    def setup(self):
        # The whole shuffle is a single einops rearrange; bind the static
        # pattern and factors once here.
        self._rearrange = partial(
            einops.rearrange,
            pattern="b h w (c h2 w2) -> b (h h2) (w w2) c",
            h2=self.scale_factor,
            w2=self.scale_factor,
        )

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return self._rearrange(x)
class ResidualBlock(nn.Module):
    """EDSR residual block: conv -> activation -> conv, with a scaled skip.

    Fix: ``res_scale`` was declared but never applied. The EDSR paper (and the
    official PyTorch implementation whose checkpoints ``convert_edsr_checkpoint``
    loads) multiplies the residual branch by ``res_scale`` (0.1 for the large
    32-block/256-feature model) to stabilize training of deep stacks; omitting
    it changes the forward pass relative to the converted weights.
    """
    channels: int               # output features of both convs
    kernel_size: Sequence[int]  # e.g. (3, 3)
    res_scale: float            # residual-branch scaling factor (paper: 0.1)
    activation: Callable        # e.g. nn.relu
    dtype: Any = jnp.float32

    def setup(self):
        self.body = nn.Sequential([
            nn.Conv(features=self.channels, kernel_size=self.kernel_size, dtype=self.dtype),
            self.activation,
            nn.Conv(features=self.channels, kernel_size=self.kernel_size, dtype=self.dtype),
        ])

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # Scale the residual branch before adding the identity skip.
        return x + self.res_scale * self.body(x)
class UpsampleBlock(nn.Module):
    """Doubles spatial resolution ``num_upsamples`` times.

    Each stage is a conv that expands channels 4x followed by a 2x
    PixelShuffle that folds those channels back into space.
    """
    num_upsamples: int
    channels: int
    kernel_size: Sequence[int]
    dtype: Any = jnp.float32

    def setup(self):
        stages = []
        for _ in range(self.num_upsamples):
            stages.append(nn.Conv(features=self.channels * 4, kernel_size=self.kernel_size, dtype=self.dtype))
            stages.append(PixelShuffle(scale_factor=2))
        # NOTE: keep the name ``layers`` — flax derives the ``layers_<i>``
        # parameter keys produced by the checkpoint converter from it.
        self.layers = stages

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        for stage in self.layers:
            x = stage(x)
        return x
class EDSR(nn.Module):
    """Enhanced Deep Residual Networks for Single Image Super-Resolution
    https://arxiv.org/pdf/1707.02921v1.pdf

    Trunk-only variant: a head conv followed by a residual body with a long
    skip connection. No upsampling tail is built here (see
    ``convert_edsr_checkpoint(no_upsampling=True)``).
    """
    scale_factor: int
    channels: int = 3
    num_blocks: int = 32
    num_feats: int = 256
    dtype: Any = jnp.float32

    def setup(self):
        # Pre-residual feature extraction.
        self.head = nn.Sequential([nn.Conv(features=self.num_feats, kernel_size=(3, 3), dtype=self.dtype)])
        # Residual trunk, terminated by one extra conv before the skip add.
        trunk = [
            ResidualBlock(channels=self.num_feats, kernel_size=(3, 3), res_scale=0.1, activation=nn.relu, dtype=self.dtype)
            for _ in range(self.num_blocks)
        ]
        trunk.append(nn.Conv(features=self.num_feats, kernel_size=(3, 3), dtype=self.dtype))
        self.body = nn.Sequential(trunk)

    def __call__(self, x: jnp.ndarray, _=None) -> jnp.ndarray:
        features = self.head(x)
        # Long skip connection around the entire residual trunk.
        return features + self.body(features)
def convert_edsr_checkpoint(torch_dict, no_upsampling=True):
    """Convert a flat PyTorch EDSR state dict into a nested flax param tree.

    Torch keys are dotted paths like ``body.10.body.0.weight``; they are
    split on ``'.'`` and nested recursively. Conv ``weight`` tensors are
    transposed from torch's OIHW to flax's HWIO layout and renamed
    ``kernel``; numeric path segments become ``layers_<n>`` to match flax
    ``nn.Sequential`` naming.

    Fix: the recursive grouping used ``k.startswith(top_key)`` without the
    trailing dot, so group ``'1'`` also swallowed keys of groups ``'10'``
    through ``'19'`` (and the slice then left a leading ``'.'`` on the
    sub-key). Any Sequential with >= 10 entries — e.g. the default
    ``num_blocks=32`` body — was silently corrupted.

    Args:
        torch_dict: flat mapping of dotted torch parameter names to tensors.
        no_upsampling: drop the ``tail`` (upsampler) subtree, matching the
            trunk-only ``EDSR`` module above.

    Returns:
        A frozen nested parameter dict (the ``add_mean``/``sub_mean``
        mean-shift layers are always removed).
    """
    def convert(in_dict):
        top_keys = {k.split('.')[0] for k in in_dict}
        leaves = {k for k in in_dict if '.' not in k}
        out_dict = {}
        for leaf in leaves:
            if leaf == 'weight':
                # torch conv weight (O, I, H, W) -> flax kernel (H, W, I, O)
                out_dict['kernel'] = jnp.asarray(in_dict[leaf]).transpose((2, 3, 1, 0))
            elif leaf == 'bias':
                out_dict[leaf] = jnp.asarray(in_dict[leaf])
            else:
                out_dict[leaf] = in_dict[leaf]
        for top_key in top_keys.difference(leaves):
            new_top_key = 'layers_' + top_key if top_key.isdigit() else top_key
            # Match the full dotted prefix so '1.' never captures '10.*'.
            prefix = top_key + '.'
            out_dict[new_top_key] = convert(
                {k[len(prefix):]: v for k, v in in_dict.items() if k.startswith(prefix)})
        return out_dict

    converted = convert(torch_dict)
    # remove torch-only subtrees that have no counterpart in the flax trunk
    if no_upsampling:
        del converted['tail']
    for k in ('add_mean', 'sub_mean'):
        del converted[k]
    return freeze(converted)