icefall/egs/librispeech/ASR/zipformer/test_subsampling.py
2023-07-25 14:46:18 +08:00

153 lines
4.6 KiB
Python
Executable File

#!/usr/bin/env python3
import torch
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
def test_conv2d_subsampling():
layer1_channels = 8
layer2_channels = 32
layer3_channels = 128
out_channels = 192
encoder_embed = Conv2dSubsampling(
in_channels=80,
out_channels=out_channels,
layer1_channels=layer1_channels,
layer2_channels=layer2_channels,
layer3_channels=layer3_channels,
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
)
N = 2
T = 200
num_features = 80
x = torch.rand(N, T, num_features)
x_copy = x.clone()
x = x.unsqueeze(1) # (N, 1, T, num_features)
x = encoder_embed.conv[0](x) # conv2d, in 1, out 8, kernel 3, padding (0,1)
assert x.shape == (N, layer1_channels, T - 2, num_features)
# (2, 8, 198, 80)
x = encoder_embed.conv[1](x) # scale grad
x = encoder_embed.conv[2](x) # balancer
x = encoder_embed.conv[3](x) # swooshR
x = encoder_embed.conv[4](x) # conv2d, in 8, out 32, kernel 3, stride 2
assert x.shape == (
N,
layer2_channels,
((T - 2) - 3) // 2 + 1,
(num_features - 3) // 2 + 1,
)
# (2, 32, 98, 39)
x = encoder_embed.conv[5](x) # balancer
x = encoder_embed.conv[6](x) # swooshR
# conv2d:
# in 32, out 128, kernel 3, stride (1, 2)
x = encoder_embed.conv[7](x)
assert x.shape == (
N,
layer3_channels,
(((T - 2) - 3) // 2 + 1) - 2,
(((num_features - 3) // 2 + 1) - 3) // 2 + 1,
)
# (2, 128, 96, 19)
x = encoder_embed.conv[8](x) # balancer
x = encoder_embed.conv[9](x) # swooshR
# (((T - 2) - 3) // 2 + 1) - 2
# = (T - 2) - 3) // 2 + 1 - 2
# = ((T - 2) - 3) // 2 - 1
# = (T - 2 - 3) // 2 - 1
# = (T - 5) // 2 - 1
# = (T - 7) // 2
assert x.shape[2] == (x_copy.shape[1] - 7) // 2
# (((num_features - 3) // 2 + 1) - 3) // 2 + 1,
# = ((num_features - 3) // 2 + 1 - 3) // 2 + 1,
# = ((num_features - 3) // 2 - 2) // 2 + 1,
# = (num_features - 3 - 4) // 2 // 2 + 1,
# = (num_features - 7) // 2 // 2 + 1,
# = (num_features - 7) // 4 + 1,
# = (num_features - 3) // 4
assert x.shape[3] == (x_copy.shape[2] - 3) // 4
assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4)
# Input shape to convnext is
#
# (N, layer3_channels, (T-7)//2, (num_features - 3)//4)
# conv2d: in layer3_channels, out layer3_channels, groups layer3_channels
# kernel_size 7, padding 3
x = encoder_embed.convnext.depthwise_conv(x)
assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4)
# conv2d: in layer3_channels, out hidden_ratio * layer3_channels, kernel_size 1
x = encoder_embed.convnext.pointwise_conv1(x)
assert x.shape == (N, layer3_channels * 3, (T - 7) // 2, (num_features - 3) // 4)
x = encoder_embed.convnext.hidden_balancer(x) # balancer
x = encoder_embed.convnext.activation(x) # swooshL
# conv2d: in hidden_ratio * layer3_channels, out layer3_channels, kernel 1
x = encoder_embed.convnext.pointwise_conv2(x)
assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4)
# bypass and layer drop, omitted here.
x = encoder_embed.convnext.out_balancer(x)
# Note: the input and output shape of ConvNeXt are the same
x = x.transpose(1, 2).reshape(N, (T - 7) // 2, -1)
assert x.shape == (N, (T - 7) // 2, layer3_channels * ((num_features - 3) // 4))
x = encoder_embed.out(x)
assert x.shape == (N, (T - 7) // 2, out_channels)
x = encoder_embed.out_whiten(x)
x = encoder_embed.out_norm(x)
# final layer is dropout
# test streaming forward
subsampling_factor = 2
cached_left_padding = encoder_embed.get_init_states(batch_size=N)
depthwise_conv_kernel_size = 7
pad_size = (depthwise_conv_kernel_size - 1) // 2
assert cached_left_padding.shape == (
N,
layer3_channels,
pad_size,
(num_features - 3) // 4,
)
chunk_size = 16
right_padding = pad_size * subsampling_factor
T = chunk_size * subsampling_factor + 7 + right_padding
x = torch.rand(N, T, num_features)
x_lens = torch.tensor([T] * N)
y, y_lens, next_cached_left_padding = encoder_embed.streaming_forward(
x, x_lens, cached_left_padding
)
assert y.shape == (N, chunk_size, out_channels), y.shape
assert next_cached_left_padding.shape == cached_left_padding.shape
assert y.shape[1] == y_lens[0] == y_lens[1]
def main():
test_conv2d_subsampling()
if __name__ == "__main__":
main()