diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 05a4cdca5..38aff8834 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -22,15 +22,15 @@ Usage: --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ + --max-duration 100 \ --decoding-method greedy_search -(2) beam search (not recommended) +(2) beam search ./pruned_transducer_stateless2/decode.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ + --max-duration 100 \ --decoding-method beam_search \ --beam-size 4 @@ -39,7 +39,7 @@ Usage: --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ + --max-duration 100 \ --decoding-method modified_beam_search \ --beam-size 4 @@ -48,7 +48,7 @@ Usage: --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ + --max-duration 1500 \ --decoding-method fast_beam_search \ --beam 4 \ --max-contexts 4 \ @@ -69,7 +69,7 @@ import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from beam_search import ( beam_search, - fast_beam_search_one_best, + fast_beam_search, greedy_search, greedy_search_batch, modified_beam_search, @@ -98,28 +98,27 @@ def get_parser(): "--epoch", type=int, default=28, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. 
- """, - ) - parser.add_argument( "--avg", type=int, default=15, help="Number of checkpoints to average. Automatically select " "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + "'--epoch'. ", + ) + + parser.add_argument( + "--avg-last-n", + type=int, + default=0, + help="""If positive, --epoch and --avg are ignored and it + will use the last n checkpoints exp_dir/checkpoint-xxx.pt + where xxx is the number of processed batches while + saving that checkpoint. + """, ) parser.add_argument( @@ -152,7 +151,7 @@ def get_parser(): "--beam-size", type=int, default=4, - help="""An integer indicating how many candidates we will keep for each + help="""An integer indicating how many candidates we will keep for each frame. Used only when --decoding-method is beam_search or modified_beam_search.""", ) @@ -252,7 +251,7 @@ def decode_one_batch( hyps = [] if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( + hyp_tokens = fast_beam_search( model=model, decoding_graph=decoding_graph, encoder_out=encoder_out, @@ -270,7 +269,6 @@ def decode_one_batch( hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -278,7 +276,6 @@ def decode_one_batch( hyp_tokens = modified_beam_search( model=model, encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, beam=params.beam_size, ) for hyp in sp.decode(hyp_tokens): @@ -358,9 +355,9 @@ def decode_dataset( num_batches = "?" 
if params.decoding_method == "greedy_search": - log_interval = 50 + log_interval = 100 else: - log_interval = 10 + log_interval = 2 results = defaultdict(list) for batch_idx, batch in enumerate(dl): @@ -456,19 +453,13 @@ def main(): ) params.res_dir = params.exp_dir / params.decoding_method - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" if "fast_beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-beam-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -485,9 +476,8 @@ def main(): sp = spm.SentencePieceProcessor() sp.load(params.bpe_model) - # <blk> and <unk> is defined in local/train_bpe_model.py + # <blk> is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("<blk>") - params.unk_id = sp.piece_to_id("<unk>") params.vocab_size = sp.get_piece_size() logging.info(params) @@ -495,20 +485,8 @@ def main(): logging.info("About to create model") model = get_transducer_model(params) - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) + if params.avg_last_n > 0: + filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n] logging.info(f"averaging {filenames}") model.to(device) 
model.load_state_dict(average_checkpoints(filenames, device=device)) diff --git a/egs/spgispeech/ASR/README.md b/egs/spgispeech/ASR/README.md index 462109493..67f21bba8 100644 --- a/egs/spgispeech/ASR/README.md +++ b/egs/spgispeech/ASR/README.md @@ -1,34 +1,33 @@ # SPGISpeech -SPGISpeech consists of 5,000 hours of recorded company earnings calls and their respective -transcriptions. The original calls were split into slices ranging from 5 to 15 seconds in -length to allow easy training for speech recognition systems. Calls represent a broad -cross-section of international business English; SPGISpeech contains approximately 50,000 -speakers, one of the largest numbers of any speech corpus, and offers a variety of L1 and +SPGISpeech consists of 5,000 hours of recorded company earnings calls and their respective +transcriptions. The original calls were split into slices ranging from 5 to 15 seconds in +length to allow easy training for speech recognition systems. Calls represent a broad +cross-section of international business English; SPGISpeech contains approximately 50,000 +speakers, one of the largest numbers of any speech corpus, and offers a variety of L1 and L2 English accents. The format of each WAV file is single channel, 16kHz, 16 bit audio. -Transcription text represents the output of several stages of manual post-processing. -As such, the text contains polished English orthography following a detailed style guide, -including proper casing, punctuation, and denormalized non-standard words such as numbers +Transcription text represents the output of several stages of manual post-processing. +As such, the text contains polished English orthography following a detailed style guide, +including proper casing, punctuation, and denormalized non-standard words such as numbers and acronyms, making SPGISpeech suited for training fully formatted end-to-end models. 
Official reference: -O’Neill, P.K., Lavrukhin, V., Majumdar, S., Noroozi, V., Zhang, Y., Kuchaiev, O., Balam, -J., Dovzhenko, Y., Freyberg, K., Shulman, M.D., Ginsburg, B., Watanabe, S., & Kucsko, G. -(2021). SPGISpeech: 5, 000 hours of transcribed financial audio for fully formatted +O’Neill, P.K., Lavrukhin, V., Majumdar, S., Noroozi, V., Zhang, Y., Kuchaiev, O., Balam, +J., Dovzhenko, Y., Freyberg, K., Shulman, M.D., Ginsburg, B., Watanabe, S., & Kucsko, G. +(2021). SPGISpeech: 5,000 hours of transcribed financial audio for fully formatted end-to-end speech recognition. ArXiv, abs/2104.02014. ArXiv link: https://arxiv.org/abs/2104.02014 ## Performance Record -| Decoding method | val | +| Decoding method | val | |---------------------------|------------| | greedy search | 2.40 | | beam search | 2.24 | -| modified beam search | 2.30 | +| modified beam search | 2.24 | | fast beam search | 2.35 | See [RESULTS](/egs/spgispeech/ASR/RESULTS.md) for details. - diff --git a/egs/spgispeech/ASR/RESULTS.md b/egs/spgispeech/ASR/RESULTS.md index f5997c408..c63b8ce90 100644 --- a/egs/spgispeech/ASR/RESULTS.md +++ b/egs/spgispeech/ASR/RESULTS.md @@ -16,7 +16,7 @@ The WERs are |---------------------------|------------|------------|------------------------------------------| | greedy search | 2.46 | 2.40 | --avg-last-n 10 --max-duration 500 | | beam search | 2.27 | 2.24 | --avg-last-n 10 --max-duration 500 --beam-size 4 | -| modified beam search | 2.34 | 2.30 | --avg-last-n 10 --max-duration 500 --beam-size 4 | +| modified beam search | 2.28 | 2.24 | --avg-last-n 10 --max-duration 500 --beam-size 4 | | fast beam search | 2.38 | 2.35 | --avg-last-n 10 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 | **NOTE:** SPGISpeech transcripts can be prepared in `ortho` or `norm` ways, which refer to whether the @@ -44,14 +44,14 @@ The decoding command is: ``` # greedy search ./pruned_transducer_stateless2/decode.py \ - --avg-last-n 10 \ + --iter 696000 --avg 10 \ --exp-dir 
./pruned_transducer_stateless2/exp \ --max-duration 100 \ --decoding-method greedy_search # beam search ./pruned_transducer_stateless2/decode.py \ - --avg-last-n 10 \ + --iter 696000 --avg 10 \ --exp-dir ./pruned_transducer_stateless2/exp \ --max-duration 100 \ --decoding-method beam_search \ @@ -59,7 +59,7 @@ The decoding command is: # modified beam search ./pruned_transducer_stateless2/decode.py \ - --avg-last-n 10 \ + --iter 696000 --avg 10 \ --exp-dir ./pruned_transducer_stateless2/exp \ --max-duration 100 \ --decoding-method modified_beam_search \ @@ -67,7 +67,7 @@ The decoding command is: # fast beam search ./pruned_transducer_stateless2/decode.py \ - --avg-last-n 10 \ + --iter 696000 --avg 10 \ --exp-dir ./pruned_transducer_stateless2/exp \ --max-duration 1500 \ --decoding-method fast_beam_search \ diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py index 491be9083..86626f058 100755 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py @@ -19,14 +19,16 @@ Usage: (1) greedy search ./pruned_transducer_stateless2/decode.py \ - --avg-last-n 10 \ + --iter 696000 \ + --avg 10 \ --exp-dir ./pruned_transducer_stateless2/exp \ --max-duration 100 \ --decoding-method greedy_search (2) beam search ./pruned_transducer_stateless2/decode.py \ - --avg-last-n 10 \ + --iter 696000 \ + --avg 10 \ --exp-dir ./pruned_transducer_stateless2/exp \ --max-duration 100 \ --decoding-method beam_search \ @@ -34,7 +36,8 @@ Usage: (3) modified beam search ./pruned_transducer_stateless2/decode.py \ - --avg-last-n 10 \ + --iter 696000 \ + --avg 10 \ --exp-dir ./pruned_transducer_stateless2/exp \ --max-duration 100 \ --decoding-method modified_beam_search \ @@ -42,7 +45,8 @@ Usage: (4) fast beam search ./pruned_transducer_stateless2/decode.py \ - --avg-last-n 10 \ + --iter 696000 \ + --avg 10 \ --exp-dir ./pruned_transducer_stateless2/exp \ 
--max-duration 1500 \ --decoding-method fast_beam_search \ @@ -93,30 +97,31 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + default=20, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", ) parser.add_argument( - "--avg-last-n", + "--iter", type=int, default=0, - help="""If positive, --epoch and --avg are ignored and it - will use the last n checkpoints exp_dir/checkpoint-xxx.pt - where xxx is the number of processed batches while - saving that checkpoint. + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. """, ) + parser.add_argument( + "--avg", + type=int, + default=10, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + parser.add_argument( "--exp-dir", type=str, @@ -182,7 +187,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -240,7 +246,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -255,10 +263,14 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -266,6 +278,7 @@ def decode_one_batch( hyp_tokens = modified_beam_search( model=model, encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, beam=params.beam_size, ) for hyp in sp.decode(hyp_tokens): @@ -375,7 +388,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -407,7 +422,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -440,13 +456,19 @@ def main(): ) params.res_dir = params.exp_dir / params.decoding_method - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = 
f"epoch-{params.epoch}-avg-{params.avg}" + if "fast_beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -472,8 +494,20 @@ def main(): logging.info("About to create model") model = get_transducer_model(params) - if params.avg_last_n > 0: - filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n] + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) logging.info(f"averaging {filenames}") model.to(device) model.load_state_dict(average_checkpoints(filenames, device=device)) diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py index f2a5c054d..6c66bfb62 100755 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py @@ -60,7 +60,6 @@ from asr_datamodule import SPGISpeechAsrDataModule from conformer import Conformer from decoder import Decoder from joiner import Joiner -from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import Transducer @@ -78,7 +77,9 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import 
AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def get_parser(): @@ -154,7 +155,8 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to be changed.", + help="The initial learning rate. This value should not need to be " + "changed.", ) parser.add_argument( @@ -177,7 +179,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -200,7 +203,8 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( @@ -550,16 +554,23 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. 
info["loss"] = loss.detach().cpu().item() @@ -722,7 +733,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -820,7 +833,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -828,9 +841,10 @@ def run(rank, world_size, args): train_cuts = spgispeech.train_cuts() - # Ideally we should filter utterances that are too long or too short, but SPGISpeech - # contains regular length utterances so we don't need to do that. Here are the - # statistics of the training data (obtained by `train_cuts.describe()`): + # Ideally we should filter utterances that are too long or too short, + # but SPGISpeech contains regular length utterances so we don't need to + # do that. Here are the statistics of the training data (obtained by + # `train_cuts.describe()`): # Cuts count: 5886320 # Total duration (hours): 15070.1