File size: 5,527 Bytes
3133fdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import itertools
import math
import unittest

import torch
from pytorchvideo.layers import PositionalEncoding, SpatioTemporalClsPositionalEncoding


class TestPositionalEncoding(unittest.TestCase):
    """Unit tests for ``PositionalEncoding`` and
    ``SpatioTemporalClsPositionalEncoding`` from ``pytorchvideo.layers``."""

    def setUp(self):
        super().setUp()
        # Seed the global RNG so the random fixtures below are reproducible.
        torch.manual_seed(42)

        self.batch_size = 4
        self.seq_len = 16
        self.feature_dim = 8
        # Random input of shape (batch, seq, feature) used by the PE tests.
        self.fake_input = torch.randn(
            (self.batch_size, self.seq_len, self.feature_dim)
        ).float()
        # Per-sample valid lengths, one entry per batch element. (The original
        # list had 8 entries for a batch of 4; trimmed to match batch_size.)
        lengths = torch.tensor([16, 0, 14, 15])
        # Boolean mask of shape (batch_size, seq_len): True where the position
        # index is below the sample's valid length.
        self.mask = torch.lt(
            torch.arange(self.seq_len)[None, :], lengths[:, None].long()
        )

    def _reference_pe(self, seq_len: int, feature_dim: int) -> torch.Tensor:
        """Build the sinusoidal positional-encoding table ("Attention Is All
        You Need"): sin on even feature indices, cos on odd ones."""
        pe = torch.zeros(seq_len, feature_dim, dtype=torch.float)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, feature_dim, 2).float()
            * (-math.log(10000.0) / feature_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe

    def test_positional_encoding(self):
        model = PositionalEncoding(self.feature_dim, self.seq_len)
        output = model(self.fake_input)
        # The module adds an input-independent table, so output - input must
        # equal the reference encoding for every batch element.
        delta = output - self.fake_input
        pe = self._reference_pe(self.seq_len, self.feature_dim)
        for n in range(self.batch_size):
            self.assertTrue(torch.allclose(delta[n], pe, atol=1e-6))

    def test_positional_encoding_with_different_pe_and_data_dimensions(self):
        """Test that the model executes even if the input data dimensions
        differ from the dimensions of the initialized positional encoding."""

        # When seq_len < positional_encoding_seq_len, the leading seq_len rows
        # of the table are added to the input.
        positional_encoding_seq_len = self.seq_len * 3
        model = PositionalEncoding(self.feature_dim, positional_encoding_seq_len)
        output = model(self.fake_input)

        delta = output - self.fake_input
        pe = self._reference_pe(self.seq_len, self.feature_dim)
        for n in range(self.batch_size):
            self.assertTrue(torch.allclose(delta[n], pe, atol=1e-6))

        # When seq_len > positional_encoding_seq_len, the module is expected
        # to reject the over-long input with an AssertionError.
        positional_encoding_seq_len = self.seq_len // 2
        model = PositionalEncoding(self.feature_dim, positional_encoding_seq_len)
        with self.assertRaises(AssertionError):
            model(self.fake_input)

    def test_SpatioTemporalClsPositionalEncoding(self):
        # With a cls token the output sequence gains one extra position.
        batch_dim = 4
        dim = 16
        video_shape = (1, 2, 4)
        video_sum = math.prod(video_shape)
        model = SpatioTemporalClsPositionalEncoding(
            embed_dim=dim,
            patch_embed_shape=video_shape,
            has_cls=True,
        )
        fake_input = torch.rand(batch_dim, video_sum, dim)
        output = model(fake_input)
        self.assertEqual(tuple(output.shape), (batch_dim, video_sum + 1, dim))

    def test_SpatioTemporalClsPositionalEncoding_nocls(self):
        # Without a cls token the output shape matches the input shape.
        batch_dim = 4
        dim = 16
        video_shape = (1, 2, 4)
        video_sum = math.prod(video_shape)
        model = SpatioTemporalClsPositionalEncoding(
            embed_dim=dim,
            patch_embed_shape=video_shape,
            has_cls=False,
        )
        fake_input = torch.rand(batch_dim, video_sum, dim)
        output = model(fake_input)
        self.assertEqual(tuple(output.shape), (batch_dim, video_sum, dim))

    def test_SpatioTemporalClsPositionalEncoding_mismatch(self):
        # patch_embed_shape must have exactly 3 entries; a 2-tuple should
        # raise an AssertionError at construction time.
        with self.assertRaises(AssertionError):
            SpatioTemporalClsPositionalEncoding(
                embed_dim=16,
                patch_embed_shape=(1, 2),
            )

    def test_SpatioTemporalClsPositionalEncoding_scriptable(self):
        # The TorchScript-compiled module must agree with eager execution for
        # every combination of configuration flags.
        iter_embed_dim = [1, 2, 4, 32]
        iter_patch_embed_shape = [(1, 1, 1), (1, 2, 4), (32, 16, 1)]
        iter_sep_pos_embed = [True, False]
        iter_has_cls = [True, False]

        for (
            embed_dim,
            patch_embed_shape,
            sep_pos_embed,
            has_cls,
        ) in itertools.product(
            iter_embed_dim,
            iter_patch_embed_shape,
            iter_sep_pos_embed,
            iter_has_cls,
        ):
            stcpe = SpatioTemporalClsPositionalEncoding(
                embed_dim=embed_dim,
                patch_embed_shape=patch_embed_shape,
                sep_pos_embed=sep_pos_embed,
                has_cls=has_cls,
            )
            stcpe_scripted = torch.jit.script(stcpe)
            batch_dim = 4
            video_dim = math.prod(patch_embed_shape)
            fake_input = torch.rand(batch_dim, video_dim, embed_dim)
            expected = stcpe(fake_input)
            actual = stcpe_scripted(fake_input)
            # torch.testing.assert_allclose is deprecated (and removed in
            # recent torch releases); assert_close is the supported API.
            torch.testing.assert_close(expected, actual)