add tests for subsampling.py and fix typos
parent 4ab7d61008 · commit d5dcca674c
egs/librispeech/ASR/zipformer/.gitignore (new file, 1 line)
@@ -0,0 +1 @@
swoosh.pdf
egs/librispeech/ASR/zipformer/scaling.py

@@ -125,7 +125,7 @@ class PiecewiseLinear(object):
                          p: 'PiecewiseLinear',
                          include_crossings: bool = False):
        """
-       Returns (self_mod, p_mod) which are equivalent piecewise lienar
+       Returns (self_mod, p_mod) which are equivalent piecewise linear
        functions to self and p, but with the same x values.

          p: the other piecewise linear function
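
Note (not part of this commit): the docstring above describes putting two piecewise linear functions onto a shared set of x values. A standalone sketch of that idea, assuming a plain list-of-(x, y)-pairs representation and a hypothetical eval_piecewise helper (not the icefall implementation):

# Sketch: put two piecewise linear functions on a common x grid by taking the
# union of their breakpoints and re-evaluating both functions at every point.
def eval_piecewise(points, x):
    # points: sorted list of (x, y) pairs; linear interpolation, clamped at the ends.
    if x <= points[0][0]:
        return points[0][1]
    if x >= points[-1][0]:
        return points[-1][1]
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if x0 <= x <= x1:
            return y0 + (y1 - y0) * (x - x0) / (x1 - x0)


f = [(0.0, 0.0), (1.0, 1.0), (2.0, 0.0)]
g = [(0.0, 1.0), (2.0, 0.5)]
xs = sorted({x for x, _ in f} | {x for x, _ in g})
f_mod = [(x, eval_piecewise(f, x)) for x in xs]
g_mod = [(x, eval_piecewise(g, x)) for x in xs]
# f_mod and g_mod now share the same x values and describe the same functions.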
@@ -166,7 +166,7 @@ class ScheduledFloat(torch.nn.Module):
    in, float(parent_module.whatever), and use it as something like a dropout prob.

    It is a floating point value whose value changes depending on the batch count of the
-   training loop. It is a piecewise linear function where you specifiy the (x,y) pairs
+   training loop. It is a piecewise linear function where you specify the (x,y) pairs
    in sorted order on x; x corresponds to the batch index. For batch-index values before the
    first x or after the last x, we just use the first or last y value.

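Note (not part of this commit): the docstring describes reading the scheduled value with float() inside a module's forward. A minimal sketch of that usage pattern, assuming scaling.py is importable (as in the new tests below) and that the training loop sets batch_count on modules; TinyLayer is a made-up example module:

import torch
from scaling import ScheduledFloat


class TinyLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # dropout prob starts at 0.3 and decays linearly to 0.1 by batch 20000,
        # the same schedule Conv2dSubsampling uses in the test below.
        self.dropout_p = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))

    def forward(self, x):
        # float(...) looks up the schedule at the current batch_count.
        p = float(self.dropout_p)
        return torch.nn.functional.dropout(x, p=p, training=self.training)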
@@ -343,7 +343,7 @@ class MaxEigLimiterFunction(torch.autograd.Function):
class BiasNormFunction(torch.autograd.Function):
    # This computes:
    #   scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
-   #   return (x - bias) * scales
+   #   return x * scales
    # (after unsqueezing the bias), but it does it in a memory-efficient way so that
    # it can just store the returned value (chances are, this will also be needed for
    # some other reason, related to the next operation, so we can save memory).
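Note (not part of this commit): the corrected comment spells out the math; a straightforward (non memory-efficient) reference version of it as a sketch, assuming the channel dimension is the last one and a per-channel bias with a scalar log_scale:

import torch


def bias_norm_ref(x: torch.Tensor, bias: torch.Tensor, log_scale: torch.Tensor) -> torch.Tensor:
    # bias: shape (num_channels,); log_scale: scalar tensor.
    scales = (torch.mean((x - bias) ** 2, dim=-1, keepdim=True)) ** -0.5 * log_scale.exp()
    return x * scales  # note: x, not (x - bias), is what gets scaled


x = torch.randn(2, 5, 8)
y = bias_norm_ref(x, bias=torch.zeros(8), log_scale=torch.tensor(0.0))
print(y.shape)  # torch.Size([2, 5, 8])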
@@ -400,8 +400,8 @@ class BiasNorm(torch.nn.Module):
    Args:
       num_channels: the number of channels, e.g. 512.
       channel_dim: the axis/dimension corresponding to the channel,
-        interprted as an offset from the input's ndim if negative.
-        shis is NOT the num_channels; it should typically be one of
+        interpreted as an offset from the input's ndim if negative.
+        This is NOT the num_channels; it should typically be one of
         {-2, -1, 0, 1, 2, 3}.
       log_scale: the initial log-scale that we multiply the output by; this
         is learnable.
@@ -1286,7 +1286,7 @@ class Dropout3(nn.Module):

class SwooshLFunction(torch.autograd.Function):
    """
-   swoosh(x) = log(1 + exp(x-4)) - 0.08*x - 0.035
+   swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035
    """

    @staticmethod
@@ -1361,7 +1361,7 @@ class SwooshLOnnx(torch.nn.Module):

class SwooshRFunction(torch.autograd.Function):
    """
-   swoosh(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687
+   swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687

    derivatives are between -0.08 and 0.92.
    """
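Note (not part of this commit): the two docstrings above give the activation formulas directly; plain-PyTorch versions as a sketch (the classes in scaling.py implement them as custom autograd functions instead, for speed and memory):

import torch


def swoosh_l_ref(x: torch.Tensor) -> torch.Tensor:
    # swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035
    return torch.log1p(torch.exp(x - 4.0)) - 0.08 * x - 0.035


def swoosh_r_ref(x: torch.Tensor) -> torch.Tensor:
    # swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687
    return torch.log1p(torch.exp(x - 1.0)) - 0.08 * x - 0.313261687


x = torch.linspace(-10, 10, steps=201)
print(swoosh_l_ref(x).min(), swoosh_r_ref(x).min())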
egs/librispeech/ASR/zipformer/subsampling.py

@@ -138,9 +138,11 @@ class ConvNeXt(nn.Module):

        x = bypass + x
        x = self.out_balancer(x)
-       x = x.transpose(1, 3)  # (N, W, H, C); need channel dim to be last
-       x = self.out_whiten(x)
-       x = x.transpose(1, 3)  # (N, C, H, W)
+
+       if x.requires_grad:
+           x = x.transpose(1, 3)  # (N, W, H, C); need channel dim to be last
+           x = self.out_whiten(x)
+           x = x.transpose(1, 3)  # (N, C, H, W)

        return x

@@ -266,6 +268,7 @@ class Conv2dSubsampling(nn.Module):
        # just one convnext layer
        self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))

+       # (in_channels-3)//4
        self.out_width = (((in_channels - 1) // 2) - 1) // 2
        self.layer3_channels = layer3_channels

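Note (not part of this commit): the new comment claims the existing expression equals (in_channels-3)//4; a quick sketch that checks the identity over a range of values (e.g. in_channels=80 gives 19 both ways):

# ((in_channels - 1) // 2 - 1) // 2 == (in_channels - 3) // 4 for all in_channels >= 3
for in_channels in range(3, 2000):
    assert (((in_channels - 1) // 2) - 1) // 2 == (in_channels - 3) // 4, in_channels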
@@ -299,7 +302,7 @@ class Conv2dSubsampling(nn.Module):
            A tensor of shape (batch_size,) containing the number of frames in

        Returns:
-         - a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
+         - a tensor of shape (N, (T-7)//2, odim)
          - output lengths, of shape (batch_size,)
        """
        # On entry, x is (N, T, idim)
@@ -310,14 +313,14 @@ class Conv2dSubsampling(nn.Module):
        x = self.conv(x)
        x = self.convnext(x)

-       # Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2)
+       # Now x is of shape (N, odim, (T-7)//2, (idim-3)//4)
        b, c, t, f = x.size()

        x = x.transpose(1, 2).reshape(b, t, c * f)
-       # now x: (N, ((T-1)//2 - 1))//2, out_width * layer3_channels))
+       # now x: (N, (T-7)//2, out_width * layer3_channels))

        x = self.out(x)
-       # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
+       # Now x is of shape (N, (T-7)//2, odim)
        x = self.out_whiten(x)
        x = self.out_norm(x)
        x = self.dropout(x)
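Note (not part of this commit): with the concrete sizes used in test_subsampling.py below (T=200 frames, idim=80 features), the corrected shape comments work out as follows; this is arithmetic only, mirroring the layer-by-layer asserts in the test:

T, idim = 200, 80
t1 = T - 2               # conv layer1: kernel 3, no padding in time      -> 198
t2 = (t1 - 3) // 2 + 1   # conv layer2: kernel 3, stride 2                -> 98
t3 = t2 - 2              # conv layer3: kernel 3, stride (1, 2)           -> 96
assert t3 == (T - 7) // 2

f1 = idim                # layer1 pads the feature dim by (0, 1)          -> 80
f2 = (f1 - 3) // 2 + 1   # layer2: kernel 3, stride 2                     -> 39
f3 = (f2 - 3) // 2 + 1   # layer3: kernel 3, stride (1, 2)                -> 19
assert f3 == (idim - 3) // 4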
@@ -328,7 +331,7 @@ class Conv2dSubsampling(nn.Module):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            x_lens = (x_lens - 7) // 2
-       assert x.size(1) == x_lens.max().item()
+       assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max())

        return x, x_lens

@@ -347,7 +350,7 @@ class Conv2dSubsampling(nn.Module):
            A tensor of shape (batch_size,) containing the number of frames in

        Returns:
-         - a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
+         - a tensor of shape (N, (T-7)//2, odim)
          - output lengths, of shape (batch_size,)
          - updated cache
        """
@@ -383,7 +386,7 @@ class Conv2dSubsampling(nn.Module):
        assert self.convnext.padding[0] == 3
        x_lens = (x_lens - 7) // 2 - 3

-       assert x.size(1) == x_lens.max().item()
+       assert x.size(1) == x_lens.max().item(), (x.shape, x_lens.max())

        return x, x_lens, cached_left_pad

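Note (not part of this commit): in the streaming branch above, an extra pad_size = 3 frames (the ConvNeXt kernel-7 padding) is subtracted from the lengths. With the chunk setup used in the new streaming test, the arithmetic closes up exactly; a sketch using those same numbers:

chunk_size = 16
pad_size = 3
subsampling_factor = 2
right_padding = pad_size * subsampling_factor
T = chunk_size * subsampling_factor + 7 + right_padding  # input frames per chunk
out_len = (T - 7) // 2 - pad_size                        # the x_lens update in streaming_forward
assert out_len == chunk_size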
egs/librispeech/ASR/zipformer/test_scaling.py (new executable file, 82 lines)
@@ -0,0 +1,82 @@
#!/usr/bin/env python3

import matplotlib.pyplot as plt
import torch
from scaling import PiecewiseLinear, ScheduledFloat, SwooshL, SwooshR


def test_piecewise_linear():
    # An identity map in the range [0, 1].
    # 1 - identity map in the range [1, 2]
    # x1=0, y1=0
    # x2=1, y2=1
    # x3=2, y3=0
    pl = PiecewiseLinear((0, 0), (1, 1), (2, 0))
    assert pl(0.25) == 0.25, pl(0.25)
    assert pl(0.625) == 0.625, pl(0.625)
    assert pl(1.25) == 0.75, pl(1.25)

    assert pl(-10) == pl(0), pl(-10)  # out of range
    assert pl(10) == pl(2), pl(10)  # out of range

    # multiplication
    pl10 = pl * 10
    assert pl10(1) == 10 * pl(1)
    assert pl10(0.5) == 10 * pl(0.5)


def test_scheduled_float():
    # Initial value is 0.2 and it decreases linearly towards 0 at 4000
    dropout = ScheduledFloat((0, 0.2), (4000, 0.0), default=0.0)
    dropout.batch_count = 0
    assert float(dropout) == 0.2, (float(dropout), dropout.batch_count)

    dropout.batch_count = 1000
    assert abs(float(dropout) - 0.15) < 1e-5, (float(dropout), dropout.batch_count)

    dropout.batch_count = 2000
    assert float(dropout) == 0.1, (float(dropout), dropout.batch_count)

    dropout.batch_count = 3000
    assert abs(float(dropout) - 0.05) < 1e-5, (float(dropout), dropout.batch_count)

    dropout.batch_count = 4000
    assert float(dropout) == 0.0, (float(dropout), dropout.batch_count)

    dropout.batch_count = 5000  # out of range
    assert float(dropout) == 0.0, (float(dropout), dropout.batch_count)


def test_swoosh():
    x1 = torch.linspace(start=-10, end=0, steps=100, dtype=torch.float32)
    x2 = torch.linspace(start=0, end=10, steps=100, dtype=torch.float32)
    x = torch.cat([x1, x2[1:]])

    left = SwooshL()(x)
    r = SwooshR()(x)

    relu = torch.nn.functional.relu(x)
    print(left[x == 0], r[x == 0])
    plt.plot(x, left, "k")
    plt.plot(x, r, "r")
    plt.plot(x, relu, "b")
    plt.axis([-10, 10, -1, 10])  # [xmin, xmax, ymin, ymax]
    plt.legend(
        [
            "SwooshL(x) = log(1 + exp(x-4)) - 0.08x - 0.035",
            "SwooshR(x) = log(1 + exp(x-1)) - 0.08x - 0.313261687",
            "ReLU(x) = max(0, x)",
        ]
    )
    plt.grid()
    plt.savefig("swoosh.pdf")


def main():
    test_piecewise_linear()
    test_scheduled_float()
    test_swoosh()


if __name__ == "__main__":
    main()
egs/librispeech/ASR/zipformer/test_subsampling.py (new executable file, 152 lines)
@@ -0,0 +1,152 @@
#!/usr/bin/env python3

import torch
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling


def test_conv2d_subsampling():
    layer1_channels = 8
    layer2_channels = 32
    layer3_channels = 128

    out_channels = 192
    encoder_embed = Conv2dSubsampling(
        in_channels=80,
        out_channels=out_channels,
        layer1_channels=layer1_channels,
        layer2_channels=layer2_channels,
        layer3_channels=layer3_channels,
        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
    )
    N = 2
    T = 200
    num_features = 80
    x = torch.rand(N, T, num_features)
    x_copy = x.clone()

    x = x.unsqueeze(1)  # (N, 1, T, num_features)

    x = encoder_embed.conv[0](x)  # conv2d, in 1, out 8, kernel 3, padding (0,1)
    assert x.shape == (N, layer1_channels, T - 2, num_features)
    # (2, 8, 198, 80)

    x = encoder_embed.conv[1](x)  # scale grad
    x = encoder_embed.conv[2](x)  # balancer
    x = encoder_embed.conv[3](x)  # swooshR

    x = encoder_embed.conv[4](x)  # conv2d, in 8, out 32, kernel 3, stride 2
    assert x.shape == (
        N,
        layer2_channels,
        ((T - 2) - 3) // 2 + 1,
        (num_features - 3) // 2 + 1,
    )
    # (2, 32, 98, 39)

    x = encoder_embed.conv[5](x)  # balancer
    x = encoder_embed.conv[6](x)  # swooshR

    # conv2d:
    # in 32, out 128, kernel 3, stride (1, 2)
    x = encoder_embed.conv[7](x)
    assert x.shape == (
        N,
        layer3_channels,
        (((T - 2) - 3) // 2 + 1) - 2,
        (((num_features - 3) // 2 + 1) - 3) // 2 + 1,
    )
    # (2, 128, 96, 19)

    x = encoder_embed.conv[8](x)  # balancer
    x = encoder_embed.conv[9](x)  # swooshR

    # (((T - 2) - 3) // 2 + 1) - 2
    # = (T - 2) - 3) // 2 + 1 - 2
    # = ((T - 2) - 3) // 2 - 1
    # = (T - 2 - 3) // 2 - 1
    # = (T - 5) // 2 - 1
    # = (T - 7) // 2
    assert x.shape[2] == (x_copy.shape[1] - 7) // 2

    # (((num_features - 3) // 2 + 1) - 3) // 2 + 1,
    # = ((num_features - 3) // 2 + 1 - 3) // 2 + 1,
    # = ((num_features - 3) // 2 - 2) // 2 + 1,
    # = (num_features - 3 - 4) // 2 // 2 + 1,
    # = (num_features - 7) // 2 // 2 + 1,
    # = (num_features - 7) // 4 + 1,
    # = (num_features - 3) // 4
    assert x.shape[3] == (x_copy.shape[2] - 3) // 4

    assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4)

    # Input shape to convnext is
    #
    # (N, layer3_channels, (T-7)//2, (num_features - 3)//4)

    # conv2d: in layer3_channels, out layer3_channels, groups layer3_channels
    # kernel_size 7, padding 3
    x = encoder_embed.convnext.depthwise_conv(x)
    assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4)

    # conv2d: in layer3_channels, out hidden_ratio * layer3_channels, kernel_size 1
    x = encoder_embed.convnext.pointwise_conv1(x)
    assert x.shape == (N, layer3_channels * 3, (T - 7) // 2, (num_features - 3) // 4)

    x = encoder_embed.convnext.hidden_balancer(x)  # balancer
    x = encoder_embed.convnext.activation(x)  # swooshL

    # conv2d: in hidden_ratio * layer3_channels, out layer3_channels, kernel 1
    x = encoder_embed.convnext.pointwise_conv2(x)
    assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4)

    # bypass and layer drop, omitted here.
    x = encoder_embed.convnext.out_balancer(x)

    # Note: the input and output shape of ConvNeXt are the same

    x = x.transpose(1, 2).reshape(N, (T - 7) // 2, -1)
    assert x.shape == (N, (T - 7) // 2, layer3_channels * ((num_features - 3) // 4))

    x = encoder_embed.out(x)
    assert x.shape == (N, (T - 7) // 2, out_channels)

    x = encoder_embed.out_whiten(x)
    x = encoder_embed.out_norm(x)
    # final layer is dropout

    # test streaming forward

    subsampling_factor = 2
    cached_left_padding = encoder_embed.get_init_states(batch_size=N)
    depthwise_conv_kernel_size = 7
    pad_size = (depthwise_conv_kernel_size - 1) // 2

    assert cached_left_padding.shape == (
        N,
        layer3_channels,
        pad_size,
        (num_features - 3) // 4,
    )

    chunk_size = 16
    right_padding = pad_size * subsampling_factor
    T = chunk_size * subsampling_factor + 7 + right_padding
    x = torch.rand(N, T, num_features)
    x_lens = torch.tensor([T] * N)
    y, y_lens, next_cached_left_padding = encoder_embed.streaming_forward(
        x, x_lens, cached_left_padding
    )

    assert y.shape == (N, chunk_size, out_channels), y.shape
    assert next_cached_left_padding.shape == cached_left_padding.shape

    assert y.shape[1] == y_lens[0] == y_lens[1]


def main():
    test_conv2d_subsampling()


if __name__ == "__main__":
    main()