File size: 5,034 Bytes
3133fdb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import copy
import unittest
import torch
import torch.nn
from pytorchvideo.layers import make_multilayer_perceptron, PositionalEncoding
from pytorchvideo.models.masked_multistream import (
LearnMaskedDefault,
LSTM,
MaskedSequential,
MaskedTemporalPooling,
TransposeMultiheadAttention,
TransposeTransformerEncoder,
)
class TestMaskedMultiStream(unittest.TestCase):
    """Unit tests for the masked multi-stream layers in
    pytorchvideo.models.masked_multistream."""

    def setUp(self):
        super().setUp()
        # Seed the global RNG so the torch.rand fixtures below are
        # deterministic across runs.
        torch.set_rng_state(torch.manual_seed(42).get_state())

    def test_masked_multistream_model(self):
        """End-to-end check of a full MaskedSequential stream:
        positional encoding -> attention -> masked pooling -> norm ->
        MLP -> learned default, fed a fully-valid mask."""
        feature_dim = 8
        mlp, out_dim = make_multilayer_perceptron([feature_dim, 2])
        input_stream = MaskedSequential(
            PositionalEncoding(feature_dim),
            TransposeMultiheadAttention(feature_dim),
            MaskedTemporalPooling(method="avg"),
            torch.nn.LayerNorm(feature_dim),
            mlp,
            LearnMaskedDefault(out_dim),
        )

        seq_len = 10
        input_tensor = torch.rand([4, seq_len, feature_dim])
        mask = _lengths2mask(
            torch.tensor([seq_len, seq_len, seq_len, seq_len]), input_tensor.shape[1]
        )
        output = input_stream(input=input_tensor, mask=mask)
        self.assertEqual(output.shape, torch.Size([4, out_dim]))

    def test_masked_temporal_pooling(self):
        """Pooling must ignore time steps past each sequence's valid length
        and emit zeros for fully-masked sequences."""
        fake_input = torch.Tensor(
            [[[4, -2], [3, 0]], [[0, 2], [4, 3]], [[3, 1], [5, 2]]]
        ).float()
        # Sequence 0 has 2 valid steps, sequence 1 has 1, sequence 2 has none.
        valid_lengths = torch.Tensor([2, 1, 0]).int()
        valid_mask = _lengths2mask(valid_lengths, fake_input.shape[1])
        expected_output_for_method = {
            "max": torch.Tensor([[4, 0], [0, 2], [0, 0]]).float(),
            "avg": torch.Tensor([[3.5, -1], [0, 2], [0, 0]]).float(),
            "sum": torch.Tensor([[7, -2], [0, 2], [0, 0]]).float(),
        }
        for method, expected_output in expected_output_for_method.items():
            model = MaskedTemporalPooling(method)
            # deepcopy guards against the module mutating its input in place.
            output = model(copy.deepcopy(fake_input), mask=valid_mask)
            self.assertTrue(torch.equal(output, expected_output))

    def test_transpose_attention(self):
        """TransposeMultiheadAttention should preserve the input shape."""
        feature_dim = 8
        seq_len = 10
        fake_input = torch.rand([4, seq_len, feature_dim])
        mask = _lengths2mask(
            torch.tensor([seq_len, seq_len, seq_len, seq_len]), fake_input.shape[1]
        )

        model = TransposeMultiheadAttention(feature_dim, num_heads=2)
        output = model(fake_input, mask=mask)
        # Fixed: was assertTrue(output.shape, fake_input.shape), which treated
        # the expected shape as the failure message and never compared values.
        self.assertEqual(output.shape, fake_input.shape)

    def test_masked_lstm(self):
        """LSTM output width is hidden_dim, or 2x hidden_dim when
        bidirectional (forward and backward final states concatenated)."""
        feature_dim = 8
        seq_len = 10
        fake_input = torch.rand([4, seq_len, feature_dim])
        mask = _lengths2mask(
            torch.tensor([seq_len, seq_len, seq_len, seq_len]), fake_input.shape[1]
        )
        hidden_dim = 128

        model = LSTM(feature_dim, hidden_dim=hidden_dim, bidirectional=False)
        output = model(fake_input, mask=mask)
        # Fixed: was assertTrue(shape, expected) — a vacuous assertion since a
        # non-empty torch.Size is always truthy. Use assertEqual to compare.
        self.assertEqual(output.shape, torch.Size([fake_input.shape[0], hidden_dim]))

        model = LSTM(feature_dim, hidden_dim=hidden_dim, bidirectional=True)
        output = model(fake_input, mask=mask)
        self.assertEqual(
            output.shape, torch.Size([fake_input.shape[0], hidden_dim * 2])
        )

    def test_masked_transpose_transformer_encoder(self):
        """The encoder should reduce (batch, seq, dim) to (batch, dim)."""
        feature_dim = 8
        seq_len = 10
        fake_input = torch.rand([4, seq_len, feature_dim])
        mask = _lengths2mask(
            torch.tensor([seq_len, seq_len, seq_len, seq_len]), fake_input.shape[1]
        )

        model = TransposeTransformerEncoder(feature_dim)
        output = model(fake_input, mask=mask)
        self.assertEqual(output.shape, (fake_input.shape[0], feature_dim))

    def test_learn_masked_default(self):
        """LearnMaskedDefault passes through rows with any valid step and
        substitutes its learned default vector for fully-masked rows."""
        feature_dim = 8
        seq_len = 10
        fake_input = torch.rand([4, feature_dim])

        # All valid mask: input is returned unchanged.
        all_valid_mask = _lengths2mask(
            torch.tensor([seq_len, seq_len, seq_len, seq_len]), fake_input.shape[1]
        )
        model = LearnMaskedDefault(feature_dim)
        output = model(fake_input, mask=all_valid_mask)
        self.assertTrue(output.equal(fake_input))

        # No valid mask: every row is replaced by the learned default.
        no_valid_mask = _lengths2mask(torch.tensor([0, 0, 0, 0]), fake_input.shape[1])
        model = LearnMaskedDefault(feature_dim)
        output = model(fake_input, mask=no_valid_mask)
        self.assertTrue(output.equal(model._learned_defaults.repeat(4, 1)))

        # Half valid mask: first two rows pass through, last two are defaulted.
        half_valid_mask = _lengths2mask(torch.tensor([1, 1, 0, 0]), fake_input.shape[1])
        model = LearnMaskedDefault(feature_dim)
        output = model(fake_input, mask=half_valid_mask)
        self.assertTrue(output[:2].equal(fake_input[:2]))
        self.assertTrue(output[2:].equal(model._learned_defaults.repeat(2, 1)))
def _lengths2mask(lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
return torch.lt(
torch.arange(seq_len, device=lengths.device)[None, :], lengths[:, None].long()
)
|