mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Add tests for subsample.py and fix typos (#1180)
This commit is contained in:
parent
4ab7d61008
commit
1dbbd7759e
57
.github/workflows/test.yml
vendored
57
.github/workflows/test.yml
vendored
@ -35,9 +35,9 @@ jobs:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: ["3.8"]
|
||||
torch: ["1.10.0"]
|
||||
torchaudio: ["0.10.0"]
|
||||
k2-version: ["1.23.2.dev20221201"]
|
||||
torch: ["1.13.0"]
|
||||
torchaudio: ["0.13.0"]
|
||||
k2-version: ["1.24.3.dev20230719"]
|
||||
|
||||
fail-fast: false
|
||||
|
||||
@ -66,14 +66,14 @@ jobs:
|
||||
pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
|
||||
pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
|
||||
|
||||
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
|
||||
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.github.io/k2/cpu.html
|
||||
pip install git+https://github.com/lhotse-speech/lhotse
|
||||
# icefall requirements
|
||||
pip uninstall -y protobuf
|
||||
pip install --no-binary protobuf protobuf==3.20.*
|
||||
|
||||
pip install kaldifst
|
||||
pip install onnxruntime
|
||||
pip install onnxruntime matplotlib
|
||||
pip install -r requirements.txt
|
||||
|
||||
- name: Install graphviz
|
||||
@ -83,13 +83,6 @@ jobs:
|
||||
python3 -m pip install -qq graphviz
|
||||
sudo apt-get -qq install graphviz
|
||||
|
||||
- name: Install graphviz
|
||||
if: startsWith(matrix.os, 'macos')
|
||||
shell: bash
|
||||
run: |
|
||||
python3 -m pip install -qq graphviz
|
||||
brew install -q graphviz
|
||||
|
||||
- name: Run tests
|
||||
if: startsWith(matrix.os, 'ubuntu')
|
||||
run: |
|
||||
@ -129,40 +122,10 @@ jobs:
|
||||
cd ../transducer_lstm
|
||||
pytest -v -s
|
||||
|
||||
- name: Run tests
|
||||
if: startsWith(matrix.os, 'macos')
|
||||
run: |
|
||||
ls -lh
|
||||
export PYTHONPATH=$PWD:$PWD/lhotse:$PYTHONPATH
|
||||
lib_path=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")
|
||||
echo "lib_path: $lib_path"
|
||||
export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH
|
||||
pytest -v -s ./test
|
||||
|
||||
# run tests for conformer ctc
|
||||
cd egs/librispeech/ASR/conformer_ctc
|
||||
cd ../zipformer
|
||||
pytest -v -s
|
||||
|
||||
cd ../pruned_transducer_stateless
|
||||
pytest -v -s
|
||||
|
||||
cd ../pruned_transducer_stateless2
|
||||
pytest -v -s
|
||||
|
||||
cd ../pruned_transducer_stateless3
|
||||
pytest -v -s
|
||||
|
||||
cd ../pruned_transducer_stateless4
|
||||
pytest -v -s
|
||||
|
||||
cd ../transducer_stateless
|
||||
pytest -v -s
|
||||
|
||||
# cd ../transducer
|
||||
# pytest -v -s
|
||||
|
||||
cd ../transducer_stateless2
|
||||
pytest -v -s
|
||||
|
||||
cd ../transducer_lstm
|
||||
pytest -v -s
|
||||
- uses: actions/upload-artifact@v2
|
||||
with:
|
||||
path: egs/librispeech/ASR/zipformer/swoosh.pdf
|
||||
name: swoosh.pdf
|
||||
|
@ -849,6 +849,8 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
||||
|
||||
"""
|
||||
if isinstance(left_context, torch.Tensor):
|
||||
left_context = left_context.item()
|
||||
self.extend_pe(x, left_context)
|
||||
x_size_1 = x.size(1) + left_context
|
||||
pos_emb = self.pe[
|
||||
|
@ -113,7 +113,7 @@ def test_rel_pos():
|
||||
|
||||
torch.onnx.export(
|
||||
encoder_pos,
|
||||
x,
|
||||
(x, torch.zeros(1, dtype=torch.int64)),
|
||||
filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
@ -139,7 +139,9 @@ def test_rel_pos():
|
||||
assert input_nodes[0].name == "x"
|
||||
assert input_nodes[0].shape == ["N", "T", num_features]
|
||||
|
||||
inputs = {input_nodes[0].name: x.numpy()}
|
||||
inputs = {
|
||||
input_nodes[0].name: x.numpy(),
|
||||
}
|
||||
onnx_y, onnx_pos_emb = session.run(["y", "pos_emb"], inputs)
|
||||
onnx_y = torch.from_numpy(onnx_y)
|
||||
onnx_pos_emb = torch.from_numpy(onnx_pos_emb)
|
||||
|
@ -265,7 +265,7 @@ def test_zipformer_encoder():
|
||||
|
||||
torch.onnx.export(
|
||||
encoder,
|
||||
(x),
|
||||
(x, torch.ones(1, dtype=torch.float32)),
|
||||
filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
@ -289,6 +289,7 @@ def test_zipformer_encoder():
|
||||
input_nodes = session.get_inputs()
|
||||
inputs = {
|
||||
input_nodes[0].name: x.numpy(),
|
||||
input_nodes[1].name: torch.ones(1, dtype=torch.float32).numpy(),
|
||||
}
|
||||
onnx_y = session.run(["y"], inputs)[0]
|
||||
onnx_y = torch.from_numpy(onnx_y)
|
||||
|
1
egs/librispeech/ASR/zipformer/.gitignore
vendored
Normal file
1
egs/librispeech/ASR/zipformer/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
swoosh.pdf
|
@ -320,7 +320,7 @@ class AsrModel(nn.Module):
|
||||
assert x_lens.ndim == 1, x_lens.shape
|
||||
assert y.num_axes == 2, y.num_axes
|
||||
|
||||
assert x.size(0) == x_lens.size(0) == y.dim0
|
||||
assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
|
||||
|
||||
# Compute encoder outputs
|
||||
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
|
||||
|
@ -125,7 +125,7 @@ class PiecewiseLinear(object):
|
||||
p: 'PiecewiseLinear',
|
||||
include_crossings: bool = False):
|
||||
"""
|
||||
Returns (self_mod, p_mod) which are equivalent piecewise lienar
|
||||
Returns (self_mod, p_mod) which are equivalent piecewise linear
|
||||
functions to self and p, but with the same x values.
|
||||
|
||||
p: the other piecewise linear function
|
||||
@ -166,7 +166,7 @@ class ScheduledFloat(torch.nn.Module):
|
||||
in, float(parent_module.whatever), and use it as something like a dropout prob.
|
||||
|
||||
It is a floating point value whose value changes depending on the batch count of the
|
||||
training loop. It is a piecewise linear function where you specifiy the (x,y) pairs
|
||||
training loop. It is a piecewise linear function where you specify the (x,y) pairs
|
||||
in sorted order on x; x corresponds to the batch index. For batch-index values before the
|
||||
first x or after the last x, we just use the first or last y value.
|
||||
|
||||
@ -343,7 +343,7 @@ class MaxEigLimiterFunction(torch.autograd.Function):
|
||||
class BiasNormFunction(torch.autograd.Function):
|
||||
# This computes:
|
||||
# scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
|
||||
# return (x - bias) * scales
|
||||
# return x * scales
|
||||
# (after unsqueezing the bias), but it does it in a memory-efficient way so that
|
||||
# it can just store the returned value (chances are, this will also be needed for
|
||||
# some other reason, related to the next operation, so we can save memory).
|
||||
@ -400,8 +400,8 @@ class BiasNorm(torch.nn.Module):
|
||||
Args:
|
||||
num_channels: the number of channels, e.g. 512.
|
||||
channel_dim: the axis/dimension corresponding to the channel,
|
||||
interprted as an offset from the input's ndim if negative.
|
||||
shis is NOT the num_channels; it should typically be one of
|
||||
interpreted as an offset from the input's ndim if negative.
|
||||
This is NOT the num_channels; it should typically be one of
|
||||
{-2, -1, 0, 1, 2, 3}.
|
||||
log_scale: the initial log-scale that we multiply the output by; this
|
||||
is learnable.
|
||||
@ -1286,7 +1286,7 @@ class Dropout3(nn.Module):
|
||||
|
||||
class SwooshLFunction(torch.autograd.Function):
|
||||
"""
|
||||
swoosh(x) = log(1 + exp(x-4)) - 0.08*x - 0.035
|
||||
swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@ -1361,7 +1361,7 @@ class SwooshLOnnx(torch.nn.Module):
|
||||
|
||||
class SwooshRFunction(torch.autograd.Function):
|
||||
"""
|
||||
swoosh(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687
|
||||
swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687
|
||||
|
||||
derivatives are between -0.08 and 0.92.
|
||||
"""
|
||||
|
@ -138,9 +138,11 @@ class ConvNeXt(nn.Module):
|
||||
|
||||
x = bypass + x
|
||||
x = self.out_balancer(x)
|
||||
x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last
|
||||
x = self.out_whiten(x)
|
||||
x = x.transpose(1, 3) # (N, C, H, W)
|
||||
|
||||
if x.requires_grad:
|
||||
x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last
|
||||
x = self.out_whiten(x)
|
||||
x = x.transpose(1, 3) # (N, C, H, W)
|
||||
|
||||
return x
|
||||
|
||||
@ -266,6 +268,7 @@ class Conv2dSubsampling(nn.Module):
|
||||
# just one convnext layer
|
||||
self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))
|
||||
|
||||
# (in_channels-3)//4
|
||||
self.out_width = (((in_channels - 1) // 2) - 1) // 2
|
||||
self.layer3_channels = layer3_channels
|
||||
|
||||
@ -299,7 +302,7 @@ class Conv2dSubsampling(nn.Module):
|
||||
A tensor of shape (batch_size,) containing the number of frames in
|
||||
|
||||
Returns:
|
||||
- a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
|
||||
- a tensor of shape (N, (T-7)//2, odim)
|
||||
- output lengths, of shape (batch_size,)
|
||||
"""
|
||||
# On entry, x is (N, T, idim)
|
||||
@ -310,14 +313,14 @@ class Conv2dSubsampling(nn.Module):
|
||||
x = self.conv(x)
|
||||
x = self.convnext(x)
|
||||
|
||||
# Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2)
|
||||
# Now x is of shape (N, odim, (T-7)//2, (idim-3)//4)
|
||||
b, c, t, f = x.size()
|
||||
|
||||
x = x.transpose(1, 2).reshape(b, t, c * f)
|
||||
# now x: (N, ((T-1)//2 - 1))//2, out_width * layer3_channels))
|
||||
# now x: (N, (T-7)//2, out_width * layer3_channels))
|
||||
|
||||
x = self.out(x)
|
||||
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
|
||||
# Now x is of shape (N, (T-7)//2, odim)
|
||||
x = self.out_whiten(x)
|
||||
x = self.out_norm(x)
|
||||
x = self.dropout(x)
|
||||
@ -328,7 +331,7 @@ class Conv2dSubsampling(nn.Module):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
x_lens = (x_lens - 7) // 2
|
||||
assert x.size(1) == x_lens.max().item()
|
||||
assert x.size(1) == x_lens.max().item() , (x.size(1), x_lens.max())
|
||||
|
||||
return x, x_lens
|
||||
|
||||
@ -347,7 +350,7 @@ class Conv2dSubsampling(nn.Module):
|
||||
A tensor of shape (batch_size,) containing the number of frames in
|
||||
|
||||
Returns:
|
||||
- a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
|
||||
- a tensor of shape (N, (T-7)//2, odim)
|
||||
- output lengths, of shape (batch_size,)
|
||||
- updated cache
|
||||
"""
|
||||
@ -383,7 +386,7 @@ class Conv2dSubsampling(nn.Module):
|
||||
assert self.convnext.padding[0] == 3
|
||||
x_lens = (x_lens - 7) // 2 - 3
|
||||
|
||||
assert x.size(1) == x_lens.max().item()
|
||||
assert x.size(1) == x_lens.max().item(), (x.shape, x_lens.max())
|
||||
|
||||
return x, x_lens, cached_left_pad
|
||||
|
||||
|
82
egs/librispeech/ASR/zipformer/test_scaling.py
Executable file
82
egs/librispeech/ASR/zipformer/test_scaling.py
Executable file
@ -0,0 +1,82 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from scaling import PiecewiseLinear, ScheduledFloat, SwooshL, SwooshR
|
||||
|
||||
|
||||
def test_piecewise_linear():
|
||||
# An identity map in the range [0, 1].
|
||||
# 1 - identity map in the range [1, 2]
|
||||
# x1=0, y1=0
|
||||
# x2=1, y2=1
|
||||
# x3=2, y3=0
|
||||
pl = PiecewiseLinear((0, 0), (1, 1), (2, 0))
|
||||
assert pl(0.25) == 0.25, pl(0.25)
|
||||
assert pl(0.625) == 0.625, pl(0.625)
|
||||
assert pl(1.25) == 0.75, pl(1.25)
|
||||
|
||||
assert pl(-10) == pl(0), pl(-10) # out of range
|
||||
assert pl(10) == pl(2), pl(10) # out of range
|
||||
|
||||
# multiplication
|
||||
pl10 = pl * 10
|
||||
assert pl10(1) == 10 * pl(1)
|
||||
assert pl10(0.5) == 10 * pl(0.5)
|
||||
|
||||
|
||||
def test_scheduled_float():
|
||||
# Initial value is 0.2 and it decreases linearly towards 0 at 4000
|
||||
dropout = ScheduledFloat((0, 0.2), (4000, 0.0), default=0.0)
|
||||
dropout.batch_count = 0
|
||||
assert float(dropout) == 0.2, (float(dropout), dropout.batch_count)
|
||||
|
||||
dropout.batch_count = 1000
|
||||
assert abs(float(dropout) - 0.15) < 1e-5, (float(dropout), dropout.batch_count)
|
||||
|
||||
dropout.batch_count = 2000
|
||||
assert float(dropout) == 0.1, (float(dropout), dropout.batch_count)
|
||||
|
||||
dropout.batch_count = 3000
|
||||
assert abs(float(dropout) - 0.05) < 1e-5, (float(dropout), dropout.batch_count)
|
||||
|
||||
dropout.batch_count = 4000
|
||||
assert float(dropout) == 0.0, (float(dropout), dropout.batch_count)
|
||||
|
||||
dropout.batch_count = 5000 # out of range
|
||||
assert float(dropout) == 0.0, (float(dropout), dropout.batch_count)
|
||||
|
||||
|
||||
def test_swoosh():
|
||||
x1 = torch.linspace(start=-10, end=0, steps=100, dtype=torch.float32)
|
||||
x2 = torch.linspace(start=0, end=10, steps=100, dtype=torch.float32)
|
||||
x = torch.cat([x1, x2[1:]])
|
||||
|
||||
left = SwooshL()(x)
|
||||
r = SwooshR()(x)
|
||||
|
||||
relu = torch.nn.functional.relu(x)
|
||||
print(left[x == 0], r[x == 0])
|
||||
plt.plot(x, left, "k")
|
||||
plt.plot(x, r, "r")
|
||||
plt.plot(x, relu, "b")
|
||||
plt.axis([-10, 10, -1, 10]) # [xmin, xmax, ymin, ymax]
|
||||
plt.legend(
|
||||
[
|
||||
"SwooshL(x) = log(1 + exp(x-4)) - 0.08x - 0.035 ",
|
||||
"SwooshR(x) = log(1 + exp(x-1)) - 0.08x - 0.313261687",
|
||||
"ReLU(x) = max(0, x)",
|
||||
]
|
||||
)
|
||||
plt.grid()
|
||||
plt.savefig("swoosh.pdf")
|
||||
|
||||
|
||||
def main():
|
||||
test_piecewise_linear()
|
||||
test_scheduled_float()
|
||||
test_swoosh()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
152
egs/librispeech/ASR/zipformer/test_subsampling.py
Executable file
152
egs/librispeech/ASR/zipformer/test_subsampling.py
Executable file
@ -0,0 +1,152 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
from scaling import ScheduledFloat
|
||||
from subsampling import Conv2dSubsampling
|
||||
|
||||
|
||||
def test_conv2d_subsampling():
|
||||
layer1_channels = 8
|
||||
layer2_channels = 32
|
||||
layer3_channels = 128
|
||||
|
||||
out_channels = 192
|
||||
encoder_embed = Conv2dSubsampling(
|
||||
in_channels=80,
|
||||
out_channels=out_channels,
|
||||
layer1_channels=layer1_channels,
|
||||
layer2_channels=layer2_channels,
|
||||
layer3_channels=layer3_channels,
|
||||
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
|
||||
)
|
||||
N = 2
|
||||
T = 200
|
||||
num_features = 80
|
||||
x = torch.rand(N, T, num_features)
|
||||
x_copy = x.clone()
|
||||
|
||||
x = x.unsqueeze(1) # (N, 1, T, num_features)
|
||||
|
||||
x = encoder_embed.conv[0](x) # conv2d, in 1, out 8, kernel 3, padding (0,1)
|
||||
assert x.shape == (N, layer1_channels, T - 2, num_features)
|
||||
# (2, 8, 198, 80)
|
||||
|
||||
x = encoder_embed.conv[1](x) # scale grad
|
||||
x = encoder_embed.conv[2](x) # balancer
|
||||
x = encoder_embed.conv[3](x) # swooshR
|
||||
|
||||
x = encoder_embed.conv[4](x) # conv2d, in 8, out 32, kernel 3, stride 2
|
||||
assert x.shape == (
|
||||
N,
|
||||
layer2_channels,
|
||||
((T - 2) - 3) // 2 + 1,
|
||||
(num_features - 3) // 2 + 1,
|
||||
)
|
||||
# (2, 32, 98, 39)
|
||||
|
||||
x = encoder_embed.conv[5](x) # balancer
|
||||
x = encoder_embed.conv[6](x) # swooshR
|
||||
|
||||
# conv2d:
|
||||
# in 32, out 128, kernel 3, stride (1, 2)
|
||||
x = encoder_embed.conv[7](x)
|
||||
assert x.shape == (
|
||||
N,
|
||||
layer3_channels,
|
||||
(((T - 2) - 3) // 2 + 1) - 2,
|
||||
(((num_features - 3) // 2 + 1) - 3) // 2 + 1,
|
||||
)
|
||||
# (2, 128, 96, 19)
|
||||
|
||||
x = encoder_embed.conv[8](x) # balancer
|
||||
x = encoder_embed.conv[9](x) # swooshR
|
||||
|
||||
# (((T - 2) - 3) // 2 + 1) - 2
|
||||
# = (T - 2) - 3) // 2 + 1 - 2
|
||||
# = ((T - 2) - 3) // 2 - 1
|
||||
# = (T - 2 - 3) // 2 - 1
|
||||
# = (T - 5) // 2 - 1
|
||||
# = (T - 7) // 2
|
||||
assert x.shape[2] == (x_copy.shape[1] - 7) // 2
|
||||
|
||||
# (((num_features - 3) // 2 + 1) - 3) // 2 + 1,
|
||||
# = ((num_features - 3) // 2 + 1 - 3) // 2 + 1,
|
||||
# = ((num_features - 3) // 2 - 2) // 2 + 1,
|
||||
# = (num_features - 3 - 4) // 2 // 2 + 1,
|
||||
# = (num_features - 7) // 2 // 2 + 1,
|
||||
# = (num_features - 7) // 4 + 1,
|
||||
# = (num_features - 3) // 4
|
||||
assert x.shape[3] == (x_copy.shape[2] - 3) // 4
|
||||
|
||||
assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4)
|
||||
|
||||
# Input shape to convnext is
|
||||
#
|
||||
# (N, layer3_channels, (T-7)//2, (num_features - 3)//4)
|
||||
|
||||
# conv2d: in layer3_channels, out layer3_channels, groups layer3_channels
|
||||
# kernel_size 7, padding 3
|
||||
x = encoder_embed.convnext.depthwise_conv(x)
|
||||
assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4)
|
||||
|
||||
# conv2d: in layer3_channels, out hidden_ratio * layer3_channels, kernel_size 1
|
||||
x = encoder_embed.convnext.pointwise_conv1(x)
|
||||
assert x.shape == (N, layer3_channels * 3, (T - 7) // 2, (num_features - 3) // 4)
|
||||
|
||||
x = encoder_embed.convnext.hidden_balancer(x) # balancer
|
||||
x = encoder_embed.convnext.activation(x) # swooshL
|
||||
|
||||
# conv2d: in hidden_ratio * layer3_channels, out layer3_channels, kernel 1
|
||||
x = encoder_embed.convnext.pointwise_conv2(x)
|
||||
assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4)
|
||||
|
||||
# bypass and layer drop, omitted here.
|
||||
x = encoder_embed.convnext.out_balancer(x)
|
||||
|
||||
# Note: the input and output shape of ConvNeXt are the same
|
||||
|
||||
x = x.transpose(1, 2).reshape(N, (T - 7) // 2, -1)
|
||||
assert x.shape == (N, (T - 7) // 2, layer3_channels * ((num_features - 3) // 4))
|
||||
|
||||
x = encoder_embed.out(x)
|
||||
assert x.shape == (N, (T - 7) // 2, out_channels)
|
||||
|
||||
x = encoder_embed.out_whiten(x)
|
||||
x = encoder_embed.out_norm(x)
|
||||
# final layer is dropout
|
||||
|
||||
# test streaming forward
|
||||
|
||||
subsampling_factor = 2
|
||||
cached_left_padding = encoder_embed.get_init_states(batch_size=N)
|
||||
depthwise_conv_kernel_size = 7
|
||||
pad_size = (depthwise_conv_kernel_size - 1) // 2
|
||||
|
||||
assert cached_left_padding.shape == (
|
||||
N,
|
||||
layer3_channels,
|
||||
pad_size,
|
||||
(num_features - 3) // 4,
|
||||
)
|
||||
|
||||
chunk_size = 16
|
||||
right_padding = pad_size * subsampling_factor
|
||||
T = chunk_size * subsampling_factor + 7 + right_padding
|
||||
x = torch.rand(N, T, num_features)
|
||||
x_lens = torch.tensor([T] * N)
|
||||
y, y_lens, next_cached_left_padding = encoder_embed.streaming_forward(
|
||||
x, x_lens, cached_left_padding
|
||||
)
|
||||
|
||||
assert y.shape == (N, chunk_size, out_channels), y.shape
|
||||
assert next_cached_left_padding.shape == cached_left_padding.shape
|
||||
|
||||
assert y.shape[1] == y_lens[0] == y_lens[1]
|
||||
|
||||
|
||||
def main():
|
||||
test_conv2d_subsampling()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -219,7 +219,7 @@ class Zipformer2(EncoderInterface):
|
||||
|
||||
(num_frames0, batch_size, _encoder_dims0) = x.shape
|
||||
|
||||
assert self.encoder_dim[0] == _encoder_dims0
|
||||
assert self.encoder_dim[0] == _encoder_dims0, (self.encoder_dim[0], _encoder_dims0)
|
||||
|
||||
feature_mask_dropout_prob = 0.125
|
||||
|
||||
@ -334,7 +334,7 @@ class Zipformer2(EncoderInterface):
|
||||
x = self._get_full_dim_output(outputs)
|
||||
x = self.downsample_output(x)
|
||||
# class Downsample has this rounding behavior..
|
||||
assert self.output_downsampling_factor == 2
|
||||
assert self.output_downsampling_factor == 2, self.output_downsampling_factor
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
lengths = (x_lens + 1) // 2
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user