diff --git a/egs/spgispeech/ASR/RESULTS.md b/egs/spgispeech/ASR/RESULTS.md
index 5086a4ab8..f5997c408 100644
--- a/egs/spgispeech/ASR/RESULTS.md
+++ b/egs/spgispeech/ASR/RESULTS.md
@@ -40,17 +40,43 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
   --use-fp16 True
 ```
 
-The tensorboard training log can be found at
-
-
 The decoding command is:
 ```
-## fast beam search
-./pruned_transducer_stateless/decode.py \
-  --avg-last-n 10 \
-  --exp-dir pruned_transducer_stateless/exp \
-  --max-duration 500 \
-  --beam-size 4 \
-  --max-contexts 4 \
-  --max-states 8
+# greedy search
+./pruned_transducer_stateless2/decode.py \
+  --avg-last-n 10 \
+  --exp-dir ./pruned_transducer_stateless2/exp \
+  --max-duration 100 \
+  --decoding-method greedy_search
+
+# beam search
+./pruned_transducer_stateless2/decode.py \
+  --avg-last-n 10 \
+  --exp-dir ./pruned_transducer_stateless2/exp \
+  --max-duration 100 \
+  --decoding-method beam_search \
+  --beam-size 4
+
+# modified beam search
+./pruned_transducer_stateless2/decode.py \
+  --avg-last-n 10 \
+  --exp-dir ./pruned_transducer_stateless2/exp \
+  --max-duration 100 \
+  --decoding-method modified_beam_search \
+  --beam-size 4
+
+# fast beam search
+./pruned_transducer_stateless2/decode.py \
+  --avg-last-n 10 \
+  --exp-dir ./pruned_transducer_stateless2/exp \
+  --max-duration 1500 \
+  --decoding-method fast_beam_search \
+  --beam 4 \
+  --max-contexts 4 \
+  --max-states 8
 ```
+
+Pretrained model is available at
+
+
+The tensorboard training log can be found at
+
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
index b5757ee8c..a5eca6e2d 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
@@ -23,8 +23,7 @@ Usage:
 ./pruned_transducer_stateless2/export.py \
   --exp-dir ./pruned_transducer_stateless2/exp \
   --bpe-model data/lang_bpe_500/bpe.model \
-  --epoch 20 \
-  --avg 10
+  --avg-last-n 10
 
 It will generate a file exp_dir/pretrained.pt
 
@@ -34,7 +33,7 @@ you can do:
 
 cd /path/to/exp_dir
 ln -s pretrained.pt epoch-9999.pt
-cd /path/to/egs/librispeech/ASR
+cd /path/to/egs/spgispeech/ASR
 ./pruned_transducer_stateless2/decode.py \
   --exp-dir ./pruned_transducer_stateless2/exp \
   --epoch 9999 \
@@ -51,7 +50,11 @@ import sentencepiece as spm
 import torch
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.utils import str2bool
 
 
@@ -77,6 +80,17 @@ def get_parser():
         "'--epoch'. ",
     )
 
+    parser.add_argument(
+        "--avg-last-n",
+        type=int,
+        default=0,
+        help="""If positive, --epoch and --avg are ignored and it
+        will use the last n checkpoints exp_dir/checkpoint-xxx.pt
+        where xxx is the number of processed batches while
+        saving that checkpoint.
+        """,
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
@@ -105,8 +119,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -141,7 +154,12 @@ def main():
 
     model.to(device)
 
-    if params.avg == 1:
+    if params.avg_last_n > 0:
+        filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))
+    elif params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     else:
         start = params.epoch - params.avg + 1
@@ -174,9 +192,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()