|
|
|
|
|
|
|
|
import itertools |
|
|
import unittest |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from pytorchvideo.models.r2plus1d import ( |
|
|
create_2plus1d_bottleneck_block, |
|
|
create_r2plus1d, |
|
|
) |
|
|
from pytorchvideo.models.resnet import create_bottleneck_block |
|
|
from torch import nn |
|
|
|
|
|
|
|
|
class TestR2plus1d(unittest.TestCase): |
|
|
def setUp(self): |
|
|
super().setUp() |
|
|
torch.set_rng_state(torch.manual_seed(42).get_state()) |
|
|
|
|
|
def test_create_r2plus1d(self): |
|
|
""" |
|
|
Test simple r2plus1d with different inputs. |
|
|
""" |
|
|
for input_channel, input_clip_length, input_crop_size in itertools.product( |
|
|
(3, 2), (4, 8), (56, 64) |
|
|
): |
|
|
stage_spatial_stride = (2, 2, 2, 2) |
|
|
stage_temporal_stride = (1, 1, 2, 2) |
|
|
|
|
|
total_spatial_stride = 2 * np.prod(stage_spatial_stride) |
|
|
total_temporal_stride = np.prod(stage_temporal_stride) |
|
|
head_pool_kernel_size = ( |
|
|
input_clip_length // total_temporal_stride, |
|
|
input_crop_size // total_spatial_stride, |
|
|
input_crop_size // total_spatial_stride, |
|
|
) |
|
|
|
|
|
model = create_r2plus1d( |
|
|
input_channel=input_channel, |
|
|
model_depth=50, |
|
|
model_num_class=400, |
|
|
dropout_rate=0.0, |
|
|
norm=nn.BatchNorm3d, |
|
|
activation=nn.ReLU, |
|
|
stem_dim_out=8, |
|
|
stem_conv_kernel_size=(1, 7, 7), |
|
|
stem_conv_stride=(1, 2, 2), |
|
|
stage_conv_b_kernel_size=((3, 3, 3),) * 4, |
|
|
stage_spatial_stride=stage_spatial_stride, |
|
|
stage_temporal_stride=stage_temporal_stride, |
|
|
stage_bottleneck=( |
|
|
create_bottleneck_block, |
|
|
create_2plus1d_bottleneck_block, |
|
|
create_2plus1d_bottleneck_block, |
|
|
create_2plus1d_bottleneck_block, |
|
|
), |
|
|
head_pool=nn.AvgPool3d, |
|
|
head_pool_kernel_size=head_pool_kernel_size, |
|
|
head_output_size=(1, 1, 1), |
|
|
head_activation=nn.Softmax, |
|
|
) |
|
|
|
|
|
|
|
|
for tensor in TestR2plus1d._get_inputs( |
|
|
input_channel, input_clip_length, input_crop_size |
|
|
): |
|
|
if tensor.shape[1] != input_channel: |
|
|
with self.assertRaises(RuntimeError): |
|
|
out = model(tensor) |
|
|
continue |
|
|
|
|
|
out = model(tensor) |
|
|
|
|
|
output_shape = out.shape |
|
|
output_shape_gt = (tensor.shape[0], 400) |
|
|
|
|
|
self.assertEqual( |
|
|
output_shape, |
|
|
output_shape_gt, |
|
|
"Output shape {} is different from expected shape {}".format( |
|
|
output_shape, output_shape_gt |
|
|
), |
|
|
) |
|
|
|
|
|
@staticmethod |
|
|
def _get_inputs( |
|
|
channel: int = 3, clip_length: int = 16, crop_size: int = 224 |
|
|
) -> torch.tensor: |
|
|
""" |
|
|
Provide different tensors as test cases. |
|
|
|
|
|
Yield: |
|
|
(torch.tensor): tensor as test case input. |
|
|
""" |
|
|
|
|
|
shapes = ( |
|
|
(1, channel, clip_length, crop_size, crop_size), |
|
|
(2, channel, clip_length, crop_size, crop_size), |
|
|
) |
|
|
for shape in shapes: |
|
|
yield torch.rand(shape) |
|
|
|