diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e04fb5655..363556bb7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -35,9 +35,9 @@ jobs: matrix: os: [ubuntu-latest] python-version: ["3.8"] - torch: ["1.10.0"] - torchaudio: ["0.10.0"] - k2-version: ["1.23.2.dev20221201"] + torch: ["1.13.0"] + torchaudio: ["0.13.0"] + k2-version: ["1.24.3.dev20230719"] fail-fast: false @@ -66,14 +66,14 @@ jobs: pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html - pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/ + pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.github.io/k2/cpu.html pip install git+https://github.com/lhotse-speech/lhotse # icefall requirements pip uninstall -y protobuf pip install --no-binary protobuf protobuf==3.20.* pip install kaldifst - pip install onnxruntime + pip install onnxruntime matplotlib pip install -r requirements.txt - name: Install graphviz @@ -83,13 +83,6 @@ jobs: python3 -m pip install -qq graphviz sudo apt-get -qq install graphviz - - name: Install graphviz - if: startsWith(matrix.os, 'macos') - shell: bash - run: | - python3 -m pip install -qq graphviz - brew install -q graphviz - - name: Run tests if: startsWith(matrix.os, 'ubuntu') run: | @@ -129,40 +122,10 @@ jobs: cd ../transducer_lstm pytest -v -s - - name: Run tests - if: startsWith(matrix.os, 'macos') - run: | - ls -lh - export PYTHONPATH=$PWD:$PWD/lhotse:$PYTHONPATH - lib_path=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())") - echo "lib_path: $lib_path" - export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH - pytest -v -s ./test - - # run tests for conformer ctc - cd egs/librispeech/ASR/conformer_ctc + cd ../zipformer pytest -v -s - cd ../pruned_transducer_stateless - pytest -v -s - - cd ../pruned_transducer_stateless2 - pytest -v -s - - cd ../pruned_transducer_stateless3 - pytest -v -s - - cd ../pruned_transducer_stateless4 - pytest -v -s - - cd ../transducer_stateless - pytest -v -s - - # cd ../transducer - # pytest -v -s - - cd ../transducer_stateless2 - pytest -v -s - - cd ../transducer_lstm - pytest -v -s + - uses: actions/upload-artifact@v2 + with: + path: egs/librispeech/ASR/zipformer/swoosh.pdf + name: swoosh.pdf diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 9bac46004..bcd419fb7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -849,6 +849,8 @@ class RelPositionalEncoding(torch.nn.Module): torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). """ + if isinstance(left_context, torch.Tensor): + left_context = left_context.item() self.extend_pe(x, left_context) x_size_1 = x.size(1) + left_context pos_emb = self.pe[ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py index 598fcf344..810da8da6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py @@ -113,7 +113,7 @@ def test_rel_pos(): torch.onnx.export( encoder_pos, - x, + (x, torch.zeros(1, dtype=torch.int64)), filename, verbose=False, opset_version=opset_version, @@ -139,7 +139,9 @@ def test_rel_pos(): assert input_nodes[0].name == "x" assert input_nodes[0].shape == ["N", "T", num_features] - inputs = {input_nodes[0].name: x.numpy()} + inputs = { + input_nodes[0].name: x.numpy(), + } onnx_y, onnx_pos_emb = session.run(["y", "pos_emb"], inputs) onnx_y = torch.from_numpy(onnx_y) onnx_pos_emb = torch.from_numpy(onnx_pos_emb) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py index 2440d267c..1e9b67226 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py @@ -265,7 +265,7 @@ def test_zipformer_encoder(): torch.onnx.export( encoder, - (x), + (x, torch.ones(1, dtype=torch.float32)), filename, verbose=False, opset_version=opset_version, @@ -289,6 +289,7 @@ def test_zipformer_encoder(): input_nodes = session.get_inputs() inputs = { input_nodes[0].name: x.numpy(), + input_nodes[1].name: torch.ones(1, dtype=torch.float32).numpy(), } onnx_y = session.run(["y"], inputs)[0] onnx_y = torch.from_numpy(onnx_y) diff --git a/egs/librispeech/ASR/zipformer/.gitignore b/egs/librispeech/ASR/zipformer/.gitignore new file mode 100644 index 000000000..e47ac1582 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/.gitignore @@ -0,0 +1 @@ +swoosh.pdf diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index b541ee697..f2f86af47 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -320,7 +320,7 @@ class AsrModel(nn.Module): assert x_lens.ndim == 1, x_lens.shape assert y.num_axes == 2, y.num_axes - assert x.size(0) == x_lens.size(0) == y.dim0 + assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0) # Compute encoder outputs encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 4ee7b7826..7c98ef045 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -125,7 +125,7 @@ class PiecewiseLinear(object): p: 'PiecewiseLinear', include_crossings: bool = False): """ - Returns (self_mod, p_mod) which are equivalent piecewise lienar + Returns (self_mod, p_mod) which are equivalent piecewise linear functions to self and p, but with the same x values. p: the other piecewise linear function @@ -166,7 +166,7 @@ class ScheduledFloat(torch.nn.Module): in, float(parent_module.whatever), and use it as something like a dropout prob. It is a floating point value whose value changes depending on the batch count of the - training loop. It is a piecewise linear function where you specifiy the (x,y) pairs + training loop. It is a piecewise linear function where you specify the (x,y) pairs in sorted order on x; x corresponds to the batch index. For batch-index values before the first x or after the last x, we just use the first or last y value. @@ -343,7 +343,7 @@ class MaxEigLimiterFunction(torch.autograd.Function): class BiasNormFunction(torch.autograd.Function): # This computes: # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp() - # return (x - bias) * scales + # return x * scales # (after unsqueezing the bias), but it does it in a memory-efficient way so that # it can just store the returned value (chances are, this will also be needed for # some other reason, related to the next operation, so we can save memory). @@ -400,8 +400,8 @@ class BiasNorm(torch.nn.Module): Args: num_channels: the number of channels, e.g. 512. channel_dim: the axis/dimension corresponding to the channel, - interprted as an offset from the input's ndim if negative. - shis is NOT the num_channels; it should typically be one of + interpreted as an offset from the input's ndim if negative. + This is NOT the num_channels; it should typically be one of {-2, -1, 0, 1, 2, 3}. log_scale: the initial log-scale that we multiply the output by; this is learnable. @@ -1286,7 +1286,7 @@ class Dropout3(nn.Module): class SwooshLFunction(torch.autograd.Function): """ - swoosh(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 + swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 """ @staticmethod @@ -1361,7 +1361,7 @@ class SwooshLOnnx(torch.nn.Module): class SwooshRFunction(torch.autograd.Function): """ - swoosh(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 + swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 derivatives are between -0.08 and 0.92. """ diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index d6bf57db4..6532ddccb 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -138,9 +138,11 @@ class ConvNeXt(nn.Module): x = bypass + x x = self.out_balancer(x) - x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last - x = self.out_whiten(x) - x = x.transpose(1, 3) # (N, C, H, W) + + if x.requires_grad: + x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last + x = self.out_whiten(x) + x = x.transpose(1, 3) # (N, C, H, W) return x @@ -266,6 +268,7 @@ class Conv2dSubsampling(nn.Module): # just one convnext layer self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7)) + # (in_channels-3)//4 self.out_width = (((in_channels - 1) // 2) - 1) // 2 self.layer3_channels = layer3_channels @@ -299,7 +302,7 @@ class Conv2dSubsampling(nn.Module): A tensor of shape (batch_size,) containing the number of frames in Returns: - - a tensor of shape (N, ((T-1)//2 - 1)//2, odim) + - a tensor of shape (N, (T-7)//2, odim) - output lengths, of shape (batch_size,) """ # On entry, x is (N, T, idim) @@ -310,14 +313,14 @@ class Conv2dSubsampling(nn.Module): x = self.conv(x) x = self.convnext(x) - # Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2) + # Now x is of shape (N, odim, (T-7)//2, (idim-3)//4) b, c, t, f = x.size() x = x.transpose(1, 2).reshape(b, t, c * f) - # now x: (N, ((T-1)//2 - 1))//2, out_width * layer3_channels)) + # now x: (N, (T-7)//2, out_width * layer3_channels)) x = self.out(x) - # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) + # Now x is of shape (N, (T-7)//2, odim) x = self.out_whiten(x) x = self.out_norm(x) x = self.dropout(x) @@ -328,7 +331,7 @@ class Conv2dSubsampling(nn.Module): with warnings.catch_warnings(): warnings.simplefilter("ignore") x_lens = (x_lens - 7) // 2 - assert x.size(1) == x_lens.max().item() + assert x.size(1) == x_lens.max().item() , (x.size(1), x_lens.max()) return x, x_lens @@ -347,7 +350,7 @@ class Conv2dSubsampling(nn.Module): A tensor of shape (batch_size,) containing the number of frames in Returns: - - a tensor of shape (N, ((T-1)//2 - 1)//2, odim) + - a tensor of shape (N, (T-7)//2, odim) - output lengths, of shape (batch_size,) - updated cache """ @@ -383,7 +386,7 @@ class Conv2dSubsampling(nn.Module): assert self.convnext.padding[0] == 3 x_lens = (x_lens - 7) // 2 - 3 - assert x.size(1) == x_lens.max().item() + assert x.size(1) == x_lens.max().item(), (x.shape, x_lens.max()) return x, x_lens, cached_left_pad diff --git a/egs/librispeech/ASR/zipformer/test_scaling.py b/egs/librispeech/ASR/zipformer/test_scaling.py new file mode 100755 index 000000000..5c04291e7 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/test_scaling.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 + +import matplotlib.pyplot as plt +import torch +from scaling import PiecewiseLinear, ScheduledFloat, SwooshL, SwooshR + + +def test_piecewise_linear(): + # An identity map in the range [0, 1]. + # 1 - identity map in the range [1, 2] + # x1=0, y1=0 + # x2=1, y2=1 + # x3=2, y3=0 + pl = PiecewiseLinear((0, 0), (1, 1), (2, 0)) + assert pl(0.25) == 0.25, pl(0.25) + assert pl(0.625) == 0.625, pl(0.625) + assert pl(1.25) == 0.75, pl(1.25) + + assert pl(-10) == pl(0), pl(-10) # out of range + assert pl(10) == pl(2), pl(10) # out of range + + # multiplication + pl10 = pl * 10 + assert pl10(1) == 10 * pl(1) + assert pl10(0.5) == 10 * pl(0.5) + + +def test_scheduled_float(): + # Initial value is 0.2 and it decreases linearly towards 0 at 4000 + dropout = ScheduledFloat((0, 0.2), (4000, 0.0), default=0.0) + dropout.batch_count = 0 + assert float(dropout) == 0.2, (float(dropout), dropout.batch_count) + + dropout.batch_count = 1000 + assert abs(float(dropout) - 0.15) < 1e-5, (float(dropout), dropout.batch_count) + + dropout.batch_count = 2000 + assert float(dropout) == 0.1, (float(dropout), dropout.batch_count) + + dropout.batch_count = 3000 + assert abs(float(dropout) - 0.05) < 1e-5, (float(dropout), dropout.batch_count) + + dropout.batch_count = 4000 + assert float(dropout) == 0.0, (float(dropout), dropout.batch_count) + + dropout.batch_count = 5000 # out of range + assert float(dropout) == 0.0, (float(dropout), dropout.batch_count) + + +def test_swoosh(): + x1 = torch.linspace(start=-10, end=0, steps=100, dtype=torch.float32) + x2 = torch.linspace(start=0, end=10, steps=100, dtype=torch.float32) + x = torch.cat([x1, x2[1:]]) + + left = SwooshL()(x) + r = SwooshR()(x) + + relu = torch.nn.functional.relu(x) + print(left[x == 0], r[x == 0]) + plt.plot(x, left, "k") + plt.plot(x, r, "r") + plt.plot(x, relu, "b") + plt.axis([-10, 10, -1, 10]) # [xmin, xmax, ymin, ymax] + plt.legend( + [ + "SwooshL(x) = log(1 + exp(x-4)) - 0.08x - 0.035 ", + "SwooshR(x) = log(1 + exp(x-1)) - 0.08x - 0.313261687", + "ReLU(x) = max(0, x)", + ] + ) + plt.grid() + plt.savefig("swoosh.pdf") + + +def main(): + test_piecewise_linear() + test_scheduled_float() + test_swoosh() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zipformer/test_subsampling.py b/egs/librispeech/ASR/zipformer/test_subsampling.py new file mode 100755 index 000000000..078227fb6 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/test_subsampling.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 + +import torch +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling + + +def test_conv2d_subsampling(): + layer1_channels = 8 + layer2_channels = 32 + layer3_channels = 128 + + out_channels = 192 + encoder_embed = Conv2dSubsampling( + in_channels=80, + out_channels=out_channels, + layer1_channels=layer1_channels, + layer2_channels=layer2_channels, + layer3_channels=layer3_channels, + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + N = 2 + T = 200 + num_features = 80 + x = torch.rand(N, T, num_features) + x_copy = x.clone() + + x = x.unsqueeze(1) # (N, 1, T, num_features) + + x = encoder_embed.conv[0](x) # conv2d, in 1, out 8, kernel 3, padding (0,1) + assert x.shape == (N, layer1_channels, T - 2, num_features) + # (2, 8, 198, 80) + + x = encoder_embed.conv[1](x) # scale grad + x = encoder_embed.conv[2](x) # balancer + x = encoder_embed.conv[3](x) # swooshR + + x = encoder_embed.conv[4](x) # conv2d, in 8, out 32, kernel 3, stride 2 + assert x.shape == ( + N, + layer2_channels, + ((T - 2) - 3) // 2 + 1, + (num_features - 3) // 2 + 1, + ) + # (2, 32, 98, 39) + + x = encoder_embed.conv[5](x) # balancer + x = encoder_embed.conv[6](x) # swooshR + + # conv2d: + # in 32, out 128, kernel 3, stride (1, 2) + x = encoder_embed.conv[7](x) + assert x.shape == ( + N, + layer3_channels, + (((T - 2) - 3) // 2 + 1) - 2, + (((num_features - 3) // 2 + 1) - 3) // 2 + 1, + ) + # (2, 128, 96, 19) + + x = encoder_embed.conv[8](x) # balancer + x = encoder_embed.conv[9](x) # swooshR + + # (((T - 2) - 3) // 2 + 1) - 2 + # = (T - 2) - 3) // 2 + 1 - 2 + # = ((T - 2) - 3) // 2 - 1 + # = (T - 2 - 3) // 2 - 1 + # = (T - 5) // 2 - 1 + # = (T - 7) // 2 + assert x.shape[2] == (x_copy.shape[1] - 7) // 2 + + # (((num_features - 3) // 2 + 1) - 3) // 2 + 1, + # = ((num_features - 3) // 2 + 1 - 3) // 2 + 1, + # = ((num_features - 3) // 2 - 2) // 2 + 1, + # = (num_features - 3 - 4) // 2 // 2 + 1, + # = (num_features - 7) // 2 // 2 + 1, + # = (num_features - 7) // 4 + 1, + # = (num_features - 3) // 4 + assert x.shape[3] == (x_copy.shape[2] - 3) // 4 + + assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4) + + # Input shape to convnext is + # + # (N, layer3_channels, (T-7)//2, (num_features - 3)//4) + + # conv2d: in layer3_channels, out layer3_channels, groups layer3_channels + # kernel_size 7, padding 3 + x = encoder_embed.convnext.depthwise_conv(x) + assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4) + + # conv2d: in layer3_channels, out hidden_ratio * layer3_channels, kernel_size 1 + x = encoder_embed.convnext.pointwise_conv1(x) + assert x.shape == (N, layer3_channels * 3, (T - 7) // 2, (num_features - 3) // 4) + + x = encoder_embed.convnext.hidden_balancer(x) # balancer + x = encoder_embed.convnext.activation(x) # swooshL + + # conv2d: in hidden_ratio * layer3_channels, out layer3_channels, kernel 1 + x = encoder_embed.convnext.pointwise_conv2(x) + assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4) + + # bypass and layer drop, omitted here. + x = encoder_embed.convnext.out_balancer(x) + + # Note: the input and output shape of ConvNeXt are the same + + x = x.transpose(1, 2).reshape(N, (T - 7) // 2, -1) + assert x.shape == (N, (T - 7) // 2, layer3_channels * ((num_features - 3) // 4)) + + x = encoder_embed.out(x) + assert x.shape == (N, (T - 7) // 2, out_channels) + + x = encoder_embed.out_whiten(x) + x = encoder_embed.out_norm(x) + # final layer is dropout + + # test streaming forward + + subsampling_factor = 2 + cached_left_padding = encoder_embed.get_init_states(batch_size=N) + depthwise_conv_kernel_size = 7 + pad_size = (depthwise_conv_kernel_size - 1) // 2 + + assert cached_left_padding.shape == ( + N, + layer3_channels, + pad_size, + (num_features - 3) // 4, + ) + + chunk_size = 16 + right_padding = pad_size * subsampling_factor + T = chunk_size * subsampling_factor + 7 + right_padding + x = torch.rand(N, T, num_features) + x_lens = torch.tensor([T] * N) + y, y_lens, next_cached_left_padding = encoder_embed.streaming_forward( + x, x_lens, cached_left_padding + ) + + assert y.shape == (N, chunk_size, out_channels), y.shape + assert next_cached_left_padding.shape == cached_left_padding.shape + + assert y.shape[1] == y_lens[0] == y_lens[1] + + +def main(): + test_conv2d_subsampling() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7d98dbeb1..b39af02b8 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -219,7 +219,7 @@ class Zipformer2(EncoderInterface): (num_frames0, batch_size, _encoder_dims0) = x.shape - assert self.encoder_dim[0] == _encoder_dims0 + assert self.encoder_dim[0] == _encoder_dims0, (self.encoder_dim[0], _encoder_dims0) feature_mask_dropout_prob = 0.125 @@ -334,7 +334,7 @@ class Zipformer2(EncoderInterface): x = self._get_full_dim_output(outputs) x = self.downsample_output(x) # class Downsample has this rounding behavior.. - assert self.output_downsampling_factor == 2 + assert self.output_downsampling_factor == 2, self.output_downsampling_factor if torch.jit.is_scripting() or torch.jit.is_tracing(): lengths = (x_lens + 1) // 2 else: