From a6fb9772a86850edd2c9717c1c25f5d7019dfc6d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 27 Nov 2022 13:24:30 +0800 Subject: [PATCH] Remove 4 layers. --- egs/librispeech/ASR/pruned_transducer_stateless7/train.py | 2 +- .../ASR/pruned_transducer_stateless7/zipformer.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index a2cbee71b..9350465e7 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -106,7 +106,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="2,5,5,4,4,5", + default="2,4,4,3,3,4", help="Number of zipformer encoder layers per stack, comma separated.", ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index d19524b8e..71d49af88 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1684,6 +1684,9 @@ class Conv2dSubsampling(nn.Module): ) out_height = (((in_channels - 1) // 2) - 1) // 2 + self.scale = nn.Parameter(torch.ones(out_height * layer3_channels)) + self.scale_max = 1.0 + self.scale_min = ScheduledFloat((0.0, 0.9), (4000.0, 0.01)) self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels, aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob()) @@ -1709,6 +1712,11 @@ class Conv2dSubsampling(nn.Module): x = x.transpose(1, 2).reshape(b, t, c * f) # now x: (N, ((T-1)//2 - 1))//2, out_height * layer3_channels)) + + x = x * limit_param_value(self.scale, + min=float(self.scale_min), + max=float(self.scale_max)) + x = self.out(x) # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) x = self.dropout(x)