diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 6f6e9137d..0f5aad60b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -123,7 +123,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="4,4,6,4", + default="2,2,4,6,4,2", help="Number of zipformer encoder layers per stack, comma separated.", ) @@ -131,7 +131,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--downsampling-factor", type=str, - default="1,2,4,2", + default="1,2,4,8,4,2", help="Downsampling factor for each stack of encoder layers.", ) @@ -139,14 +139,14 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--feedforward-dim", type=str, - default="1536,1536,1536,1536", + default="384,768,1024,1536,1024,768", help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", ) parser.add_argument( "--num-heads", type=str, - default="8,8,8,8", + default="4,4,4,8,4,4", help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", ) @@ -160,7 +160,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--encoder-dim", type=str, - default="384", + default="192,256,320,384,320,256", help="Embedding dimension in encoder stacks: a single int or comma-separated list." ) @@ -195,7 +195,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--encoder-unmasked-dim", type=str, - default="256", + default="164,192,256,256,256,192", help="Unmasked dimensions in the encoders, relates to augmentation during training. " "A single int or comma-separated list. Must be <= each corresponding encoder_dim." ) @@ -203,7 +203,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--cnn-module-kernel", type=str, - default="31", + default="31,31,15,15,15,31", help="Sizes of convolutional kernels in convolution modules in each encoder stack: " "a single int or comma-separated list.", )