From 02b4b469a263aa19d08a9876e12c33d7092f8309 Mon Sep 17 00:00:00 2001
From: Desh Raj
Date: Fri, 13 May 2022 14:03:38 -0400
Subject: [PATCH] remove change in librispeech

---
 .../pruned_transducer_stateless2/decode.py | 80 ++++++++++++-------
 1 file changed, 51 insertions(+), 29 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
index 38aff8834..05a4cdca5 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -22,15 +22,15 @@ Usage:
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless2/exp \
-    --max-duration 100 \
+    --max-duration 600 \
     --decoding-method greedy_search
 
-(2) beam search
+(2) beam search (not recommended)
 ./pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless2/exp \
-    --max-duration 100 \
+    --max-duration 600 \
     --decoding-method beam_search \
     --beam-size 4
 
@@ -39,7 +39,7 @@ Usage:
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless2/exp \
-    --max-duration 100 \
+    --max-duration 600 \
     --decoding-method modified_beam_search \
     --beam-size 4
 
@@ -48,7 +48,7 @@ Usage:
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless2/exp \
-    --max-duration 1500 \
+    --max-duration 600 \
     --decoding-method fast_beam_search \
     --beam 4 \
     --max-contexts 4 \
@@ -69,7 +69,7 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from beam_search import (
     beam_search,
-    fast_beam_search,
+    fast_beam_search_one_best,
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
@@ -98,27 +98,28 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 0.
+        You can specify --avg to use more checkpoints for model averaging.""",
     )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
-    )
-
-    parser.add_argument(
-        "--avg-last-n",
-        type=int,
-        default=0,
-        help="""If positive, --epoch and --avg are ignored and it
-        will use the last n checkpoints exp_dir/checkpoint-xxx.pt
-        where xxx is the number of processed batches while
-        saving that checkpoint.
-        """,
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
@@ -151,7 +152,7 @@ def get_parser():
         "--beam-size",
         type=int,
         default=4,
-        help="""An interger indicating how many candidates we will keep for each
+        help="""An integer indicating how many candidates we will keep for each
         frame. Used only when --decoding-method is beam_search or
         modified_beam_search.""",
     )
@@ -251,7 +252,7 @@ def decode_one_batch(
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
-        hyp_tokens = fast_beam_search(
+        hyp_tokens = fast_beam_search_one_best(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
@@ -269,6 +270,7 @@ def decode_one_batch(
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -276,6 +278,7 @@ def decode_one_batch(
         hyp_tokens = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
         )
         for hyp in sp.decode(hyp_tokens):
@@ -355,9 +358,9 @@ def decode_dataset(
     num_batches = "?"
 
     if params.decoding_method == "greedy_search":
-        log_interval = 100
+        log_interval = 50
     else:
-        log_interval = 2
+        log_interval = 10
 
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@@ -453,13 +456,19 @@ def main():
     )
     params.res_dir = params.exp_dir / params.decoding_method
 
-    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
     if "fast_beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam}"
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-beam-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
    else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -476,8 +485,9 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)
 
-    # <blk> is defined in local/train_bpe_model.py
+    # <blk> and <unk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
 
     logging.info(params)
@@ -485,8 +495,20 @@ def main():
     logging.info("About to create model")
     model = get_transducer_model(params)
 
-    if params.avg_last_n > 0:
-        filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
+    if params.iter > 0:
+        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+            : params.avg
+        ]
+        if len(filenames) == 0:
+            raise ValueError(
+                f"No checkpoints found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
+        elif len(filenames) < params.avg:
+            raise ValueError(
+                f"Not enough checkpoints ({len(filenames)}) found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
         logging.info(f"averaging {filenames}")
         model.to(device)
         model.load_state_dict(average_checkpoints(filenames, device=device))
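
Note on the --iter/--avg change above: a minimal, self-contained sketch of how
the new checkpoint selection in main() is expected to behave. The
find_checkpoints stub below (its filename parsing and sorting) is an assumption
made for illustration; the real helper lives in icefall/checkpoint.py and is
called here as find_checkpoints(params.exp_dir, iteration=-params.iter).

# Hypothetical sketch -- mirrors the selection logic in the patched main().
from pathlib import Path
from typing import List


def find_checkpoints(exp_dir: Path, iteration: int = 0) -> List[str]:
    # Stub standing in for icefall's helper (assumption): collect
    # checkpoint-<batch>.pt files and return them newest first. A negative
    # `iteration` keeps only checkpoints saved at or before abs(iteration)
    # processed batches, matching the negative-iteration call in main().
    found = []
    for path in exp_dir.glob("checkpoint-*.pt"):
        batch = int(path.stem.split("-")[1])
        if iteration >= 0 or batch <= -iteration:
            found.append((batch, str(path)))
    return [name for _, name in sorted(found, reverse=True)]


def select_checkpoints(exp_dir: Path, iter_: int, avg: int) -> List[str]:
    # Same slicing and error handling as the patch: take the `avg` newest
    # checkpoints at or before --iter, failing loudly when none (or too
    # few) exist to average.
    filenames = find_checkpoints(exp_dir, iteration=-iter_)[:avg]
    if len(filenames) == 0:
        raise ValueError(
            f"No checkpoints found for --iter {iter_}, --avg {avg}"
        )
    elif len(filenames) < avg:
        raise ValueError(
            f"Not enough checkpoints ({len(filenames)}) found for"
            f" --iter {iter_}, --avg {avg}"
        )
    return filenames

For example, with checkpoints saved every 8000 batches, --iter 48000 --avg 3
would average checkpoint-48000.pt, checkpoint-40000.pt, and
checkpoint-32000.pt, assuming those files exist in exp_dir.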