From 9630f9a3baca1406247fb5881044551b4d3b7f63 Mon Sep 17 00:00:00 2001
From: Guanbo Wang
Date: Sun, 15 May 2022 00:57:40 -0400
Subject: [PATCH] Update GigaSpeech results (#364)

* Update decode.py

* Update export.py

* Update results

* Update README.md
---
 README.md                                     | 13 ++--
 egs/gigaspeech/ASR/README.md                  |  2 +-
 egs/gigaspeech/ASR/RESULTS.md                 | 29 ++++-----
 .../pruned_transducer_stateless2/decode.py    | 64 ++++++++++++-------
 .../pruned_transducer_stateless2/export.py    | 42 ++++++++++--
 5 files changed, 102 insertions(+), 48 deletions(-)

diff --git a/README.md b/README.md
index 47bf0e212..8911a4336 100644
--- a/README.md
+++ b/README.md
@@ -200,19 +200,22 @@ We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless mod
 
 ### GigaSpeech
 
+We provide two models for this recipe: [Conformer CTC model][GigaSpeech_conformer_ctc]
+and [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][GigaSpeech_pruned_transducer_stateless2].
+
 #### Conformer CTC
 
 |     | Dev   | Test  |
 |-----|-------|-------|
 | WER | 10.47 | 10.58 |
 
-#### Pruned stateless RNN-T
+#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
 
 |                      | Dev   | Test  |
 |----------------------|-------|-------|
-| greedy search        | 10.59 | 10.87 |
-| fast beam search     | 10.56 | 10.80 |
-| modified beam search | 10.52 | 10.62 |
+| greedy search        | 10.51 | 10.73 |
+| fast beam search     | 10.50 | 10.69 |
+| modified beam search | 10.40 | 10.51 |
 
 ## Deployment with C++
 
@@ -238,6 +241,8 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad
 [TIMIT_tdnn_ligru_ctc]: egs/timit/ASR/tdnn_ligru_ctc
 [TED-LIUM3_transducer_stateless]: egs/tedlium3/ASR/transducer_stateless
 [TED-LIUM3_pruned_transducer_stateless]: egs/tedlium3/ASR/pruned_transducer_stateless
+[GigaSpeech_conformer_ctc]: egs/gigaspeech/ASR/conformer_ctc
+[GigaSpeech_pruned_transducer_stateless2]: egs/gigaspeech/ASR/pruned_transducer_stateless2
 [yesno]: egs/yesno/ASR
 [librispeech]: egs/librispeech/ASR
 [aishell]: egs/aishell/ASR

diff --git a/egs/gigaspeech/ASR/README.md b/egs/gigaspeech/ASR/README.md
index 1fca69e8b..32a0457c6 100644
--- a/egs/gigaspeech/ASR/README.md
+++ b/egs/gigaspeech/ASR/README.md
@@ -16,6 +16,6 @@ ln -sfv /path/to/GigaSpeech download/GigaSpeech
 |                                | Dev   | Test  |
 |--------------------------------|-------|-------|
 | `conformer_ctc`                | 10.47 | 10.58 |
-| `pruned_transducer_stateless2` | 10.52 | 10.62 |
+| `pruned_transducer_stateless2` | 10.40 | 10.51 |
 
 See [RESULTS](/egs/gigaspeech/ASR/RESULTS.md) for details.

diff --git a/egs/gigaspeech/ASR/RESULTS.md b/egs/gigaspeech/ASR/RESULTS.md
index de7b84202..7ab565844 100644
--- a/egs/gigaspeech/ASR/RESULTS.md
+++ b/egs/gigaspeech/ASR/RESULTS.md
@@ -11,13 +11,15 @@ decoder contains only an embedding layer, a Conv1d (with kernel size 2) and a
 linear layer (to transform tensor dim). k2 pruned RNN-T loss is used.
 
+The best WER, as of 2022-05-12, for GigaSpeech is below
+
 Results are:
 
 |                      | Dev   | Test  |
 |----------------------|-------|-------|
-| greedy search        | 10.59 | 10.87 |
-| fast beam search     | 10.56 | 10.80 |
-| modified beam search | 10.52 | 10.62 |
+| greedy search        | 10.51 | 10.73 |
+| fast beam search     | 10.50 | 10.69 |
+| modified beam search | 10.40 | 10.51 |
 
 To reproduce the above result, use the following commands for training:
 
@@ -39,33 +41,30 @@ and the following commands for decoding:
 ```bash
 # greedy search
 ./pruned_transducer_stateless2/decode.py \
-    --epoch 29 \
-    --avg 11 \
+    --iter 3488000 \
+    --avg 20 \
     --decoding-method greedy_search \
     --exp-dir pruned_transducer_stateless2/exp \
    --bpe-model data/lang_bpe_500/bpe.model \
-    --max-duration 20 \
-    --num-workers 1
+    --max-duration 600
 
 # fast beam search
 ./pruned_transducer_stateless2/decode.py \
-    --epoch 29 \
-    --avg 9 \
+    --iter 3488000 \
+    --avg 20 \
     --decoding-method fast_beam_search \
     --exp-dir pruned_transducer_stateless2/exp \
     --bpe-model data/lang_bpe_500/bpe.model \
-    --max-duration 20 \
-    --num-workers 1
+    --max-duration 600
 
 # modified beam search
 ./pruned_transducer_stateless2/decode.py \
-    --epoch 29 \
-    --avg 8 \
+    --iter 3488000 \
+    --avg 15 \
     --decoding-method modified_beam_search \
     --exp-dir pruned_transducer_stateless2/exp \
     --bpe-model data/lang_bpe_500/bpe.model \
-    --max-duration 20 \
-    --num-workers 1
+    --max-duration 600
 ```
 
 Pretrained model is available at

diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py
index 92a5b0b28..ce5116336 100755
--- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py
@@ -22,7 +22,7 @@ Usage:
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless2/exp \
-    --max-duration 100 \
+    --max-duration 600 \
     --decoding-method greedy_search
 
 (2) beam search
@@ -30,7 +30,7 @@ Usage:
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless2/exp \
-    --max-duration 100 \
+    --max-duration 600 \
     --decoding-method beam_search \
     --beam-size 4
 
@@ -39,7 +39,7 @@ Usage:
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless2/exp \
-    --max-duration 100 \
+    --max-duration 600 \
     --decoding-method modified_beam_search \
     --beam-size 4
 
@@ -48,7 +48,7 @@ Usage:
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless2/exp \
-    --max-duration 1500 \
+    --max-duration 600 \
     --decoding-method fast_beam_search \
     --beam 4 \
     --max-contexts 4 \
@@ -99,27 +99,28 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 0.
+        You can specify --avg to use more checkpoints for model averaging.""",
     )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
     parser.add_argument(
         "--avg",
         type=int,
         default=8,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
-    )
-
-    parser.add_argument(
-        "--avg-last-n",
-        type=int,
-        default=0,
-        help="""If positive, --epoch and --avg are ignored and it
-        will use the last n checkpoints exp_dir/checkpoint-xxx.pt
-        where xxx is the number of processed batches while
-        saving that checkpoint.
-        """,
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
@@ -152,7 +153,7 @@ def get_parser():
         "--beam-size",
         type=int,
         default=4,
-        help="""An interger indicating how many candidates we will keep for each
+        help="""An integer indicating how many candidates we will keep for each
         frame. Used only when --decoding-method is beam_search or
         modified_beam_search.""",
     )
@@ -465,7 +466,11 @@ def main():
     )
 
     params.res_dir = params.exp_dir / params.decoding_method
 
-    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
     if "fast_beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam}"
         params.suffix += f"-max-contexts-{params.max_contexts}"
@@ -488,8 +493,9 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)
 
-    # <blk> is defined in local/train_bpe_model.py
+    # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
 
     logging.info(params)
@@ -497,8 +503,20 @@ def main():
     logging.info("About to create model")
     model = get_transducer_model(params)
 
-    if params.avg_last_n > 0:
-        filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
+    if params.iter > 0:
+        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+            : params.avg
+        ]
+        if len(filenames) == 0:
+            raise ValueError(
+                f"No checkpoints found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
+        elif len(filenames) < params.avg:
+            raise ValueError(
+                f"Not enough checkpoints ({len(filenames)}) found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
         logging.info(f"averaging {filenames}")
         model.to(device)
         model.load_state_dict(average_checkpoints(filenames, device=device))

diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py
index b5757ee8c..6b3a7a9ff 100755
--- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py
@@ -51,7 +51,11 @@ import sentencepiece as spm
 import torch
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.utils import str2bool
 
 
@@ -64,8 +68,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help="""It specifies the checkpoint to use for averaging.
+        Note: Epoch counts from 0.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
     )
 
     parser.add_argument(
@@ -74,7 +89,7 @@ def get_parser():
         default=15,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
@@ -141,7 +156,24 @@ def main():
 
     model.to(device)
 
-    if params.avg == 1:
+    if params.iter > 0:
+        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+            : params.avg
+        ]
+        if len(filenames) == 0:
+            raise ValueError(
+                f"No checkpoints found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
+        elif len(filenames) < params.avg:
+            raise ValueError(
+                f"Not enough checkpoints ({len(filenames)}) found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))
+    elif params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     else:
         start = params.epoch - params.avg + 1