From aab72bc2a546872ac08a4396b382810b90af1cba Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 24 Mar 2022 13:10:54 +0800 Subject: [PATCH] Add changes from master to decode.py, train.py --- .../pruned_transducer_stateless2/decode.py | 27 ++++++++++++++----- .../ASR/pruned_transducer_stateless2/train.py | 19 ++++++++++--- 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index ad76411c0..8e924bf96 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -71,6 +71,7 @@ from beam_search import ( beam_search, fast_beam_search, greedy_search, + greedy_search_batch, modified_beam_search, ) from train import get_params, get_transducer_model @@ -191,7 +192,7 @@ def get_parser(): parser.add_argument( "--max-sym-per-frame", type=int, - default=3, + default=1, help="""Maximum number of symbols per frame. Used only when --decoding_method is greedy_search""", ) @@ -261,6 +262,24 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) else: batch_size = encoder_out.size(0) @@ -280,12 +299,6 @@ def decode_one_batch( encoder_out=encoder_out_i, beam=params.beam_size, ) - elif params.decoding_method == "modified_beam_search": - hyp = modified_beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index d4a2e83d5..01cf289f5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -398,12 +398,16 @@ def load_checkpoint_if_available( "batch_idx_train", "best_train_loss", "best_valid_loss", - "cur_batch_idx", ] for k in keys: params[k] = saved_params[k] - params["start_epoch"] = saved_params["cur_epoch"] + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] return saved_params @@ -762,11 +766,20 @@ def run(rank, world_size, args): def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold return 1.0 <= c.duration <= 20.0 train_cuts = train_cuts.filter(remove_short_and_long_utt) - if checkpoints and "sampler" in checkpoints: + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch sampler_state_dict = checkpoints["sampler"] else: sampler_state_dict = None