commit cf50e16047 (parent 4b6edaa4a3): export model
@@ -1,5 +1,47 @@

## Results

### Aishell training results (Conformer-MMI)

#### 2021-12-01

(Pingfeng Luo): Result of <https://github.com/k2-fsa/icefall/pull/123>

The tensorboard log for training is available at <https://tensorboard.dev/experiment/dyp3vWE9RE6SkqBAgLJjUw/>

A pretrained model is available at <https://huggingface.co/pfluo/icefall_aishell_model>

The best decoding results (CER) are listed below. They were obtained by averaging the models from epoch 20 to 49 and decoding with the `attention-decoder` method with `num_paths` set to 100.
| | test |
|--|--|
| CER | 5.12% |

The decoding scales that produced this result are:

| | lm_scale | attention_scale |
|--|--|--|
| test | 1.5 | 0.5 |
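For reference, these two scales weight the language-model and attention-decoder scores when the n-best paths drawn from the lattice are rescored. A rough, hypothetical sketch of that combination (illustrative only, not the exact icefall implementation):

```python
def combined_score(am_score, lm_score, attn_score, lm_scale=1.5, attention_scale=0.5):
    """Hypothetical combined score of one n-best path under the given scales."""
    return am_score + lm_scale * lm_score + attention_scale * attn_score
```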
You can use the following commands to reproduce our results:

```bash
git clone https://github.com/k2-fsa/icefall
cd icefall

cd egs/aishell/ASR
./prepare.sh

export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
python conformer_mmi/train.py --bucketing-sampler True \
                              --max-duration 200 \
                              --start-epoch 0 \
                              --num-epoch 50 \
                              --world-size 8

python3 conformer_mmi/decode.py --nbest-scale 0.5 \
                                --epoch 49 \
                                --avg 20 \
                                --method attention-decoder \
                                --max-duration 20 \
                                --num-paths 100
```
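For readers unfamiliar with the `--avg` option used above: checkpoint averaging takes an element-wise mean of the parameters stored in the selected `epoch-*.pt` files. A minimal sketch of the idea, as a simplified stand-in for icefall's `average_checkpoints` (not the actual implementation):

```python
import torch

def average_checkpoints_sketch(filenames, device=torch.device("cpu")):
    """Element-wise mean of the "model" state_dicts saved in `filenames`."""
    n = len(filenames)
    avg = torch.load(filenames[0], map_location=device)["model"]
    for name in filenames[1:]:
        state = torch.load(name, map_location=device)["model"]
        for k in avg:
            avg[k] += state[k]
    for k in avg:
        # Integer buffers cannot be divided in place as floats.
        if avg[k].is_floating_point():
            avg[k] /= n
        else:
            avg[k] //= n
    return avg
```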
### Aishell training results (Conformer-CTC)

#### 2021-11-16

(Wei Kang): Result of <https://github.com/k2-fsa/icefall/pull/30>
@@ -542,8 +542,14 @@ def main():
             if start >= 0:
                 filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
         logging.info(f"averaging {filenames}")
-        model.to(device)
-        model.load_state_dict(average_checkpoints(filenames, device=device))
+        model.load_state_dict(average_checkpoints(filenames))
+
+    if params.export:
+        logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
+        return
 
     model.to(device)
     model.eval()
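The exported `pretrained.pt` holds only `{"model": model.state_dict()}`, so it can later be loaded without any optimizer state. A minimal, hypothetical usage sketch (the stand-in model below is for illustration only; in practice you must build the same Conformer architecture, with the same arguments, that was trained):

```python
import torch
import torch.nn as nn

# Stand-in architecture for illustration; the real consumer must construct the
# same model whose averaged parameters were exported above.
model = nn.Sequential(nn.Linear(80, 256), nn.ReLU(), nn.Linear(256, 512))

# pretrained.pt was saved as {"model": state_dict}; the path is the
# {params.exp_dir}/pretrained.pt location from the diff above.
checkpoint = torch.load("exp/pretrained.pt", map_location="cpu")
model.load_state_dict(checkpoint["model"])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
```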