From c57798661c968ce2f69ceccb06d26fa7b97189ae Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 24 Dec 2021 14:12:32 +0800 Subject: [PATCH] Make --context-size configurable. --- egs/librispeech/ASR/transducer_stateless/decode.py | 12 ++++++++++-- egs/librispeech/ASR/transducer_stateless/export.py | 10 ++++++++-- .../ASR/transducer_stateless/pretrained.py | 10 ++++++++-- egs/librispeech/ASR/transducer_stateless/train.py | 10 ++++++++-- 4 files changed, 34 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py index 88cb2ffc1..fba50a241 100755 --- a/egs/librispeech/ASR/transducer_stateless/decode.py +++ b/egs/librispeech/ASR/transducer_stateless/decode.py @@ -114,6 +114,14 @@ def get_parser(): help="Used only when --decoding-method is beam_search", ) + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + return parser @@ -129,8 +137,6 @@ def get_params() -> AttributeDict: "dim_feedforward": 2048, "num_encoder_layers": 12, "vgg_frontend": False, - # parameters for decoder - "context_size": 2, # tri-gram "env_info": get_env_info(), } ) @@ -379,6 +385,8 @@ def main(): params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" if params.decoding_method == "beam_search": params.suffix += f"-beam-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") diff --git a/egs/librispeech/ASR/transducer_stateless/export.py b/egs/librispeech/ASR/transducer_stateless/export.py index 0367ecf64..641555bdb 100755 --- a/egs/librispeech/ASR/transducer_stateless/export.py +++ b/egs/librispeech/ASR/transducer_stateless/export.py @@ -104,6 +104,14 @@ def get_parser(): """, ) + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + return parser @@ -119,8 +127,6 @@ def get_params() -> AttributeDict: "dim_feedforward": 2048, "num_encoder_layers": 12, "vgg_frontend": False, - # parameters for decoder - "context_size": 2, # tri-gram "env_info": get_env_info(), } ) diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py index 5be3a0944..77046bea9 100755 --- a/egs/librispeech/ASR/transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py @@ -110,6 +110,14 @@ def get_parser(): help="Used only when --method is beam_search", ) + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + return parser @@ -126,8 +134,6 @@ def get_params() -> AttributeDict: "dim_feedforward": 2048, "num_encoder_layers": 12, "vgg_frontend": False, - # parameters for decoder - "context_size": 2, # tri-gram "env_info": get_env_info(), } ) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index f67b28061..694ebf1d5 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -130,6 +130,14 @@ def get_parser(): help="The lr_factor for Noam optimizer", ) + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + return parser @@ -196,8 +204,6 @@ def get_params() -> AttributeDict: "dim_feedforward": 2048, "num_encoder_layers": 12, "vgg_frontend": False, - # parameters for decoder - "context_size": 2, # tri-gram # parameters for Noam "warm_step": 80000, # For the 100h subset, use 8k "env_info": get_env_info(),