# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import itertools
import math
import unittest

import torch
from pytorchvideo.layers import PositionalEncoding, SpatioTemporalClsPositionalEncoding


class TestPositionalEncoding(unittest.TestCase):
    def setUp(self):
        super().setUp()
        torch.set_rng_state(torch.manual_seed(42).get_state())
        self.batch_size = 4
        self.seq_len = 16
        self.feature_dim = 8
        self.fake_input = torch.randn(
            (self.batch_size, self.seq_len, self.feature_dim)
        ).float()
        lengths = torch.Tensor([16, 0, 14, 15, 16, 16, 16, 16])
        self.mask = torch.lt(
            torch.arange(self.seq_len)[None, :], lengths[:, None].long()
        )
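
    # For reference, the expected values recomputed in the tests below follow the
    # standard sinusoidal encoding (restated here for readability; the layer is
    # assumed to implement exactly this form, which the assertions verify):
    #   PE[pos, 2i]     = sin(pos / 10000^(2i / feature_dim))
    #   PE[pos, 2i + 1] = cos(pos / 10000^(2i / feature_dim))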

    def test_positional_encoding(self):
        model = PositionalEncoding(self.feature_dim, self.seq_len)
        output = model(self.fake_input)
        delta = output - self.fake_input

        pe = torch.zeros(self.seq_len, self.feature_dim, dtype=torch.float)
        position = torch.arange(0, self.seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.feature_dim, 2).float()
            * (-math.log(10000.0) / self.feature_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        for n in range(0, self.batch_size):
            self.assertTrue(torch.allclose(delta[n], pe, atol=1e-6))

    def test_positional_encoding_with_different_pe_and_data_dimensions(self):
        """Test that the model executes even if the input data dimensions
        differ from the dimensions of the initialized positional encoding model."""

        # When self.seq_len < positional_encoding_seq_len, pe is added to the input.
        positional_encoding_seq_len = self.seq_len * 3
        model = PositionalEncoding(self.feature_dim, positional_encoding_seq_len)
        output = model(self.fake_input)
        delta = output - self.fake_input

        pe = torch.zeros(self.seq_len, self.feature_dim, dtype=torch.float)
        position = torch.arange(0, self.seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.feature_dim, 2).float()
            * (-math.log(10000.0) / self.feature_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        for n in range(0, self.batch_size):
            self.assertTrue(torch.allclose(delta[n], pe, atol=1e-6))

        # When self.seq_len > positional_encoding_seq_len, an assertion error is raised.
        positional_encoding_seq_len = self.seq_len // 2
        model = PositionalEncoding(self.feature_dim, positional_encoding_seq_len)
        with self.assertRaises(AssertionError):
            output = model(self.fake_input)

    def test_SpatioTemporalClsPositionalEncoding(self):
        # Test with cls token.
        batch_dim = 4
        dim = 16
        video_shape = (1, 2, 4)
        video_sum = video_shape[0] * video_shape[1] * video_shape[2]
        has_cls = True
        model = SpatioTemporalClsPositionalEncoding(
            embed_dim=dim,
            patch_embed_shape=video_shape,
            has_cls=has_cls,
        )
        fake_input = torch.rand(batch_dim, video_sum, dim)
        output = model(fake_input)
        # With has_cls=True the layer prepends a class token, so the expected
        # sequence length is video_sum + 1.
        output_gt_shape = (batch_dim, video_sum + 1, dim)
        self.assertEqual(tuple(output.shape), output_gt_shape)

    def test_SpatioTemporalClsPositionalEncoding_nocls(self):
        # Test without cls token.
        batch_dim = 4
        dim = 16
        video_shape = (1, 2, 4)
        video_sum = video_shape[0] * video_shape[1] * video_shape[2]
        has_cls = False
        model = SpatioTemporalClsPositionalEncoding(
            embed_dim=dim,
            patch_embed_shape=video_shape,
            has_cls=has_cls,
        )
        fake_input = torch.rand(batch_dim, video_sum, dim)
        output = model(fake_input)
        output_gt_shape = (batch_dim, video_sum, dim)
        self.assertEqual(tuple(output.shape), output_gt_shape)

    def test_SpatioTemporalClsPositionalEncoding_mismatch(self):
        # Mismatch in dimension for patch_embed_shape.
        with self.assertRaises(AssertionError):
            SpatioTemporalClsPositionalEncoding(
                embed_dim=16,
                patch_embed_shape=(1, 2),
            )
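
    # The scripting test below checks that torch.jit.script produces the same
    # outputs as eager execution across the embed_dim / patch_embed_shape /
    # sep_pos_embed / has_cls configurations enumerated with itertools.product.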

    def test_SpatioTemporalClsPositionalEncoding_scriptable(self):
        iter_embed_dim = [1, 2, 4, 32]
        iter_patch_embed_shape = [(1, 1, 1), (1, 2, 4), (32, 16, 1)]
        iter_sep_pos_embed = [True, False]
        iter_has_cls = [True, False]

        for (
            embed_dim,
            patch_embed_shape,
            sep_pos_embed,
            has_cls,
        ) in itertools.product(
            iter_embed_dim,
            iter_patch_embed_shape,
            iter_sep_pos_embed,
            iter_has_cls,
        ):
            stcpe = SpatioTemporalClsPositionalEncoding(
                embed_dim=embed_dim,
                patch_embed_shape=patch_embed_shape,
                sep_pos_embed=sep_pos_embed,
                has_cls=has_cls,
            )
            stcpe_scripted = torch.jit.script(stcpe)

            batch_dim = 4
            video_dim = math.prod(patch_embed_shape)
            fake_input = torch.rand(batch_dim, video_dim, embed_dim)
            expected = stcpe(fake_input)
            actual = stcpe_scripted(fake_input)
            torch.testing.assert_allclose(expected, actual)
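

# A minimal usage sketch, kept outside the test class: it mirrors how the tests
# above build flattened patch embeddings and applies the class-token positional
# encoding to them. The concrete sizes below are illustrative, not taken from
# the tests.
if __name__ == "__main__":
    pos_enc = SpatioTemporalClsPositionalEncoding(
        embed_dim=16,
        patch_embed_shape=(2, 4, 4),
        has_cls=True,
    )
    tokens = torch.rand(2, 2 * 4 * 4, 16)  # (batch, T * H * W, embed_dim)
    # With has_cls=True a class token is prepended, so the printed shape should
    # be torch.Size([2, 33, 16]).
    print(pos_enc(tokens).shape)

    unittest.main()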