diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index c03598895..8e4733bfc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -93,35 +93,35 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="7,7", + default="7,4,4", help="Number of zipformer encoder layers, comma separated.", ) parser.add_argument( "--feedforward-dims", type=str, - default="1536,1536", + default="1536,1536,1536", help="Feedforward dimension of the zipformer encoder layers, comma separated.", ) parser.add_argument( "--nhead", type=str, - default="8,8", + default="8,8,8", help="Number of attention heads in the zipformer encoder layers.", ) parser.add_argument( "--encoder-dims", type=str, - default="384,384", + default="384,384,512", help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated" ) parser.add_argument( "--attention-dims", type=str, - default="192,192", + default="192,192,256", help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; not the same as embedding dimension.""" ) @@ -129,7 +129,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--encoder-unmasked-dims", type=str, - default="256,256", + default="256,256,256", help="Unmasked dimensions in the encoders, relates to augmentation during training. " "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " " worse." @@ -138,10 +138,17 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--zipformer-downsampling-factors", type=str, - default="1,2", + default="1,2,4", help="Downsampling factor for each stack of encoder layers.", ) + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31", + help="Sizes of kernels in convolution modules", + ) + parser.add_argument( "--decoder-dim", type=int, @@ -443,6 +450,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), nhead=to_int_tuple(params.nhead), feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), num_encoder_layers=to_int_tuple(params.num_encoder_layers), ) return encoder diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 70f1a71d0..01e2b0b1c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -72,7 +72,7 @@ class Zipformer(EncoderInterface): feedforward_dim: Tuple[int] = (1536, 2048), num_encoder_layers: Tuple[int] = (12, 12), dropout: float = 0.1, - cnn_module_kernel: Tuple[int] = (31, 31), + cnn_module_kernels: Tuple[int] = (31, 31), warmup_batches: float = 4000.0, ) -> None: super(Zipformer, self).__init__() @@ -111,7 +111,7 @@ class Zipformer(EncoderInterface): nhead[i], feedforward_dim[i], dropout, - cnn_module_kernel[i], + cnn_module_kernels[i], ) # For the segment of the warmup period, we let the Conv2dSubsampling