From 890cd1ab7529a5d284f0b083fa6f7ec3b5a80d74 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Tue, 6 Sep 2022 10:23:40 +0800 Subject: [PATCH] fix bugs --- egs/librispeech/ASR/tdnn_lstm_ctc2/decode.py | 8 +++++++- egs/librispeech/ASR/tdnn_lstm_ctc2/model.py | 5 ++++- .../ASR/tdnn_lstm_ctc2/pretrained.py | 6 ++++++ egs/librispeech/ASR/tdnn_lstm_ctc2/train.py | 17 ++++++----------- 4 files changed, 23 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc2/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc2/decode.py index f1aacb5e7..ba2db3e38 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc2/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc2/decode.py @@ -125,7 +125,7 @@ def get_parser(): def get_params() -> AttributeDict: params = AttributeDict( { - "exp_dir": Path("tdnn_lstm_ctc/exp/"), + "exp_dir": Path("tdnn_lstm_ctc2/exp/"), "lang_dir": Path("data/lang_phone"), "lm_dir": Path("data/lm"), "feature_dim": 80, @@ -136,6 +136,11 @@ def get_params() -> AttributeDict: "max_active_states": 10000, "use_double_scores": True, "env_info": get_env_info(), + "grad_norm_threshold": 10.0, + # For each sequence element in batch, its gradient will be + # filtered out if the gradient norm is larger than + # `grad_norm_threshold * median`, where `median` is the median + # value of gradient norms of all elememts in batch. } ) return params @@ -452,6 +457,7 @@ def main(): num_features=params.feature_dim, num_classes=max_phone_id + 1, # +1 for the blank symbol subsampling_factor=params.subsampling_factor, + grad_norm_threshold=params.grad_norm_threshold, ) if params.avg == 1: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc2/model.py b/egs/librispeech/ASR/tdnn_lstm_ctc2/model.py index f8e5f5e9b..7187bfbaf 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc2/model.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc2/model.py @@ -144,7 +144,10 @@ class TdnnLstm(nn.Module): [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)] ) self.grad_filters = nn.ModuleList( - [GradientFilter(batch_dim=1, threshold=grad_norm_threshold)] + [ + GradientFilter(batch_dim=1, threshold=grad_norm_threshold) + for _ in range(5) + ] ) self.dropout = nn.Dropout(0.2) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc2/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc2/pretrained.py index 2baeb6bba..3db80e894 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc2/pretrained.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc2/pretrained.py @@ -124,6 +124,11 @@ def get_params() -> AttributeDict: "min_active_states": 30, "max_active_states": 10000, "use_double_scores": True, + "grad_norm_threshold": 10.0, + # For each sequence element in batch, its gradient will be + # filtered out if the gradient norm is larger than + # `grad_norm_threshold * median`, where `median` is the median + # value of gradient norms of all elememts in batch. } ) return params @@ -172,6 +177,7 @@ def main(): num_features=params.feature_dim, num_classes=params.num_classes, subsampling_factor=params.subsampling_factor, + grad_norm_threshold=params.grad_norm_threshold, ) checkpoint = torch.load(args.checkpoint, map_location="cpu") diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc2/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc2/train.py index 9497ca09e..41dd77425 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc2/train.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc2/train.py @@ -19,7 +19,7 @@ """ Usage: export CUDA_VISIBLE_DEVICES="0,1,2,3" - ./tdnn_lstm_ctc/train.py \ + ./tdnn_lstm_ctc2/train.py \ --world-size 4 \ --full-libri 1 \ --max-duration 300 \ @@ -112,16 +112,6 @@ def get_parser(): help="The seed for random generators intended for reproducibility", ) - parser.add_argument( - "--grad-norm-threshold", - type=float, - default=10.0, - help="""For each sequence element in batch, its gradient will be - filtered out if the gradient norm is larger than - `grad_norm_threshold * median`, where `median` is the median - value of gradient norms of all elememts in batch.""", - ) - return parser @@ -199,6 +189,11 @@ def get_params() -> AttributeDict: "reduction": "sum", "use_double_scores": True, "env_info": get_env_info(), + "grad_norm_threshold": 10.0, + # For each sequence element in batch, its gradient will be + # filtered out if the gradient norm is larger than + # `grad_norm_threshold * median`, where `median` is the median + # value of gradient norms of all elememts in batch. } )