#!/usr/bin/env python3

import torch
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
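
# This test walks through the submodules of Conv2dSubsampling one at a time,
# checking the tensor shape after each stage, and then exercises the
# streaming_forward() path with cached left padding.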


def test_conv2d_subsampling():
    layer1_channels = 8
    layer2_channels = 32
    layer3_channels = 128

    out_channels = 192
    encoder_embed = Conv2dSubsampling(
        in_channels=80,
        out_channels=out_channels,
        layer1_channels=layer1_channels,
        layer2_channels=layer2_channels,
        layer3_channels=layer3_channels,
        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
    )
    N = 2
    T = 200
    num_features = 80
    x = torch.rand(N, T, num_features)
    x_copy = x.clone()

    x = x.unsqueeze(1)  # (N, 1, T, num_features)
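
    # The conv stack is exercised submodule by submodule below: three conv2d
    # layers (indices 0, 4, 7), each followed by a Balancer and a SwooshR
    # activation, plus a gradient-scaling layer at index 1.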
    x = encoder_embed.conv[0](x)  # conv2d, in 1, out 8, kernel 3, padding (0,1)
    assert x.shape == (N, layer1_channels, T - 2, num_features)
    # (2, 8, 198, 80)

    x = encoder_embed.conv[1](x)  # scale grad
    x = encoder_embed.conv[2](x)  # balancer
    x = encoder_embed.conv[3](x)  # swooshR

    x = encoder_embed.conv[4](x)  # conv2d, in 8, out 32, kernel 3, stride 2
    assert x.shape == (
        N,
        layer2_channels,
        ((T - 2) - 3) // 2 + 1,
        (num_features - 3) // 2 + 1,
    )
    # (2, 32, 98, 39)

    x = encoder_embed.conv[5](x)  # balancer
    x = encoder_embed.conv[6](x)  # swooshR

    # conv2d:
    # in 32, out 128, kernel 3, stride (1, 2)
    x = encoder_embed.conv[7](x)
    assert x.shape == (
        N,
        layer3_channels,
        (((T - 2) - 3) // 2 + 1) - 2,
        (((num_features - 3) // 2 + 1) - 3) // 2 + 1,
    )
    # (2, 128, 96, 19)

    x = encoder_embed.conv[8](x)  # balancer
    x = encoder_embed.conv[9](x)  # swooshR

    # (((T - 2) - 3) // 2 + 1) - 2
    # = ((T - 2) - 3) // 2 + 1 - 2
    # = ((T - 2) - 3) // 2 - 1
    # = (T - 2 - 3) // 2 - 1
    # = (T - 5) // 2 - 1
    # = (T - 7) // 2
    assert x.shape[2] == (x_copy.shape[1] - 7) // 2

    # (((num_features - 3) // 2 + 1) - 3) // 2 + 1
    # = ((num_features - 3) // 2 + 1 - 3) // 2 + 1
    # = ((num_features - 3) // 2 - 2) // 2 + 1
    # = (num_features - 3 - 4) // 2 // 2 + 1
    # = (num_features - 7) // 2 // 2 + 1
    # = (num_features - 7) // 4 + 1
    # = (num_features - 3) // 4
    assert x.shape[3] == (x_copy.shape[2] - 3) // 4

    assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4)
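
    # Concretely, with T = 200 and num_features = 80, the conv stack maps
    # (2, 1, 200, 80) to (2, 128, (200 - 7) // 2, (80 - 3) // 4) = (2, 128, 96, 19).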

    # Input shape to convnext is
    #
    # (N, layer3_channels, (T-7)//2, (num_features - 3)//4)

    # conv2d: in layer3_channels, out layer3_channels, groups layer3_channels
    # kernel_size 7, padding 3
    x = encoder_embed.convnext.depthwise_conv(x)
    assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4)

    # conv2d: in layer3_channels, out hidden_ratio * layer3_channels, kernel_size 1
    x = encoder_embed.convnext.pointwise_conv1(x)
    assert x.shape == (N, layer3_channels * 3, (T - 7) // 2, (num_features - 3) // 4)

    x = encoder_embed.convnext.hidden_balancer(x)  # balancer
    x = encoder_embed.convnext.activation(x)  # swooshL

    # conv2d: in hidden_ratio * layer3_channels, out layer3_channels, kernel 1
    x = encoder_embed.convnext.pointwise_conv2(x)
    assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4)

    # bypass and layer drop, omitted here.
    x = encoder_embed.convnext.out_balancer(x)

    # Note: the input and output shapes of ConvNeXt are the same

    x = x.transpose(1, 2).reshape(N, (T - 7) // 2, -1)
    assert x.shape == (N, (T - 7) // 2, layer3_channels * ((num_features - 3) // 4))

    x = encoder_embed.out(x)
    assert x.shape == (N, (T - 7) // 2, out_channels)

    x = encoder_embed.out_whiten(x)
    x = encoder_embed.out_norm(x)
    # final layer is dropout

    # test streaming forward

    subsampling_factor = 2
    cached_left_padding = encoder_embed.get_init_states(batch_size=N)
    depthwise_conv_kernel_size = 7
    pad_size = (depthwise_conv_kernel_size - 1) // 2

    assert cached_left_padding.shape == (
        N,
        layer3_channels,
        pad_size,
        (num_features - 3) // 4,
    )
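
    # The cache holds pad_size = 3 frames of left context for the ConvNeXt
    # depthwise conv (kernel_size 7, padding 3), at the subsampled frame rate
    # and feature dimension.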

    chunk_size = 16
    right_padding = pad_size * subsampling_factor
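    # Choose T so that (T - 7) // 2 == chunk_size + pad_size: the extra
    # pad_size subsampled frames supply the right padding consumed by the
    # depthwise conv, and streaming_forward() should then emit exactly
    # chunk_size frames per chunk.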
    T = chunk_size * subsampling_factor + 7 + right_padding
    x = torch.rand(N, T, num_features)
    x_lens = torch.tensor([T] * N)
    y, y_lens, next_cached_left_padding = encoder_embed.streaming_forward(
        x, x_lens, cached_left_padding
    )

    assert y.shape == (N, chunk_size, out_channels), y.shape
    assert next_cached_left_padding.shape == cached_left_padding.shape

    assert y.shape[1] == y_lens[0] == y_lens[1]


def main():
    test_conv2d_subsampling()


if __name__ == "__main__":
    main()