From 807816fec0dde1bfa0f0a2f20d36552cc3d84a90 Mon Sep 17 00:00:00 2001
From: Erwan Zerhouni <61225408+ezerhouni@users.noreply.github.com>
Date: Wed, 18 Oct 2023 10:07:10 +0200
Subject: [PATCH] Fix chunk issue for sherpa (#1316)

---
 egs/librispeech/ASR/zipformer/zipformer.py | 31 ++++++++++++++++++-------------
 1 file changed, 18 insertions(+), 13 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py
index 1a174b315..61ae378d8 100644
--- a/egs/librispeech/ASR/zipformer/zipformer.py
+++ b/egs/librispeech/ASR/zipformer/zipformer.py
@@ -17,28 +17,33 @@
 # limitations under the License.
 
 import copy
+import logging
 import math
+import random
 import warnings
 from typing import List, Optional, Tuple, Union
-import logging
+
 import torch
-import random
 from encoder_interface import EncoderInterface
 from scaling import (
+    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+)
+from scaling import (
+    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+)
+from scaling import (
+    ActivationDropoutAndLinear,
     Balancer,
     BiasNorm,
-    Dropout2,
     ChunkCausalDepthwiseConv1d,
-    ActivationDropoutAndLinear,
-    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+    Dropout2,
+    FloatLike,
+    ScheduledFloat,
     Whiten,
-    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+    convert_num_channels,
+    limit_param_value,
     penalize_abs_values_gt,
     softmax,
-    ScheduledFloat,
-    FloatLike,
-    limit_param_value,
-    convert_num_channels,
 )
 from torch import Tensor, nn
 
@@ -2098,7 +2103,7 @@ class NonlinAttention(nn.Module):
         (seq_len, batch_size, _) = x.shape
         hidden_channels = self.hidden_channels
 
-        s, x, y = x.chunk(3, dim=-1)
+        s, x, y = x.chunk(3, dim=2)
 
         # s will go through tanh.
 
@@ -2151,7 +2156,7 @@ class NonlinAttention(nn.Module):
         (seq_len, batch_size, _) = x.shape
         hidden_channels = self.hidden_channels
 
-        s, x, y = x.chunk(3, dim=-1)
+        s, x, y = x.chunk(3, dim=2)
 
         # s will go through tanh.
         s = self.tanh(s)
@@ -2308,7 +2313,7 @@ class ConvolutionModule(nn.Module):
 
         x = self.in_proj(x)  # (time, batch, 2*channels)
 
-        x, s = x.chunk(2, dim=-1)
+        x, s = x.chunk(2, dim=2)
         s = self.balancer1(s)
         s = self.sigmoid(s)
         x = self.activation1(x)  # identity.
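
Reviewer note (not part of the patch): beyond the import reshuffle, the functional change above is purely x.chunk(..., dim=-1) -> x.chunk(..., dim=2). The tensors involved are all 3-D, shaped (seq_len, batch_size, channels), so both spellings select the same last axis and the eager-mode numerics are unchanged. Below is a minimal sketch of that equivalence; the concrete shapes are illustrative placeholders, not values taken from the patch:

    import torch

    # Illustrative shape mirroring NonlinAttention.forward after in_proj:
    # (seq_len, batch_size, 3 * hidden_channels).
    x = torch.randn(50, 4, 3 * 128)

    # On a 3-D tensor, dim=2 and dim=-1 name the same (last) axis, so both
    # calls split the channel dimension into three equal pieces.
    s1, x1, y1 = x.chunk(3, dim=-1)
    s2, x2, y2 = x.chunk(3, dim=2)

    assert torch.equal(s1, s2)
    assert torch.equal(x1, x2)
    assert torch.equal(y1, y2)

Since the two forms are identical in eager mode, the motivation is presumably on the deployment side: a runtime that reads the split axis out of an exported graph can mishandle a negative index, whereas an explicit dim=2 is unambiguous. That reading is an inference from the subject line ("Fix chunk issue for sherpa"); the patch itself does not state the root cause.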