From cf50e16047c76c6c44a81f30bab0417b92028d8c Mon Sep 17 00:00:00 2001
From: PingFeng Luo
Date: Wed, 1 Dec 2021 18:19:03 +0800
Subject: [PATCH] export model

---
 egs/aishell/ASR/RESULTS.md              | 42 +++++++++++++++++++++++++
 egs/aishell/ASR/conformer_mmi/decode.py | 10 ++++--
 2 files changed, 50 insertions(+), 2 deletions(-)

diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md
index 589d9ee7b..aaa581768 100644
--- a/egs/aishell/ASR/RESULTS.md
+++ b/egs/aishell/ASR/RESULTS.md
@@ -1,5 +1,47 @@
 ## Results
 
+### Aishell training results (Conformer-MMI)
+#### 2021-12-01
+(Pingfeng Luo): Result of
+
+The tensorboard log for training is available at
+
+And the pretrained model is available at
+
+The best decoding results (CER) are listed below. They were obtained by averaging the models from epoch 20 to epoch 49 and decoding with the `attention-decoder` method, with `num_paths` set to 100.
+
+||test|
+|--|--|
+|CER| 5.12% |
+
+||lm_scale|attention_scale|
+|--|--|--|
+|test|1.5|0.5|
+
+You can use the following commands to reproduce our results:
+
+```bash
+git clone https://github.com/k2-fsa/icefall
+cd icefall
+
+cd egs/aishell/ASR
+./prepare.sh
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+python conformer_mmi/train.py --bucketing-sampler True \
+                              --max-duration 200 \
+                              --start-epoch 0 \
+                              --num-epochs 50 \
+                              --world-size 8
+
+python3 conformer_mmi/decode.py --nbest-scale 0.5 \
+                                --epoch 49 \
+                                --avg 20 \
+                                --method attention-decoder \
+                                --max-duration 20 \
+                                --num-paths 100
+```
+
 ### Aishell training results (Conformer-CTC)
 #### 2021-11-16
 (Wei Kang): Result of https://github.com/k2-fsa/icefall/pull/30
diff --git a/egs/aishell/ASR/conformer_mmi/decode.py b/egs/aishell/ASR/conformer_mmi/decode.py
index 9c42a8f38..a80717b75 100755
--- a/egs/aishell/ASR/conformer_mmi/decode.py
+++ b/egs/aishell/ASR/conformer_mmi/decode.py
@@ -542,8 +542,14 @@ def main():
             if start >= 0:
                 filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
         logging.info(f"averaging {filenames}")
-        model.to(device)
-        model.load_state_dict(average_checkpoints(filenames, device=device))
+        model.load_state_dict(average_checkpoints(filenames))
+
+    if params.export:
+        logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
+        return
 
     model.to(device)
     model.eval()
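Note: the new branch in decode.py reads `params.export`, so decode.py is expected to expose a matching `--export` command-line option. A minimal sketch of such a flag follows; the flag name mirrors the attribute read above, but the type, default, and help text are assumptions and are not part of this patch.

```python
# Hypothetical sketch only: the flag name mirrors `params.export` as read in
# decode.py above; the type, default, and help text are assumptions.
import argparse

parser = argparse.ArgumentParser(description="decode (sketch)")
parser.add_argument(
    "--export",
    action="store_true",
    help="If given, save the averaged model to <exp-dir>/pretrained.pt "
    "and exit without running decoding.",
)

args = parser.parse_args(["--export"])
print(args.export)  # -> True
```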
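For completeness, here is a usage sketch for the exported checkpoint. Only the `{"model": model.state_dict()}` layout is taken from the `torch.save` call above; the file path is an assumed default, and the model construction is left as a comment because it depends on the training configuration.

```python
# Minimal sketch for consuming the exported checkpoint. The path below assumes
# the default exp dir of the conformer_mmi recipe; adjust as needed.
import torch

ckpt_path = "egs/aishell/ASR/conformer_mmi/exp/pretrained.pt"

# decode.py --export saves a plain dict with a single "model" key holding the
# averaged state_dict (saved on CPU, since the model is averaged before being
# moved to the decoding device).
checkpoint = torch.load(ckpt_path, map_location="cpu")
state_dict = checkpoint["model"]

print(f"loaded {len(state_dict)} parameter tensors")

# To decode with it, build the same Conformer model used for training and then:
#   model.load_state_dict(state_dict)
#   model.eval()
```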