mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
export model
This commit is contained in:
parent
4b6edaa4a3
commit
cf50e16047
@ -1,5 +1,47 @@
|
|||||||
## Results
|
## Results
|
||||||
|
|
||||||
|
### Aishell training results (Conformer-MMI)
|
||||||
|
#### 2021-12-01
|
||||||
|
(Pingfeng Luo): Result of <https://github.com/k2-fsa/icefall/pull/123>
|
||||||
|
|
||||||
|
The tensorboard log for training is available at <https://tensorboard.dev/experiment/dyp3vWE9RE6SkqBAgLJjUw/>
|
||||||
|
|
||||||
|
The pretrained model is available at <https://huggingface.co/pfluo/icefall_aishell_model>
|
||||||
|
|
||||||
|
The best decoding results (CER) are listed below. We got these results by averaging models from epoch 20 to 49 and using the `attention-decoder` decoding method with num_paths equal to 100.
|
||||||
|
|
||||||
|
||test|
|
||||||
|
|--|--|
|
||||||
|
|CER| 5.12% |
|
||||||
|
|
||||||
|
||lm_scale|attention_scale|
|
||||||
|
|--|--|--|
|
||||||
|
|test|1.5|0.5|
|
||||||
|
|
||||||
|
You can use the following commands to reproduce our results:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/k2-fsa/icefall
|
||||||
|
cd icefall
|
||||||
|
|
||||||
|
cd egs/aishell/ASR
|
||||||
|
./prepare.sh
|
||||||
|
|
||||||
|
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
||||||
|
python conformer_ctc/train.py --bucketing-sampler True \
|
||||||
|
--max-duration 200 \
|
||||||
|
--start-epoch 0 \
|
||||||
|
--num-epoch 50 \
|
||||||
|
--world-size 8
|
||||||
|
|
||||||
|
python3 conformer_ctc/decode.py --nbest-scale 0.5 \
|
||||||
|
--epoch 49 \
|
||||||
|
--avg 20 \
|
||||||
|
--method attention-decoder \
|
||||||
|
--max-duration 20 \
|
||||||
|
--num-paths 100
|
||||||
|
```
|
||||||
|
|
||||||
### Aishell training results (Conformer-CTC)
|
### Aishell training results (Conformer-CTC)
|
||||||
#### 2021-11-16
|
#### 2021-11-16
|
||||||
(Wei Kang): Result of https://github.com/k2-fsa/icefall/pull/30
|
(Wei Kang): Result of https://github.com/k2-fsa/icefall/pull/30
|
||||||
|
@ -542,8 +542,14 @@ def main():
|
|||||||
if start >= 0:
|
if start >= 0:
|
||||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
logging.info(f"averaging {filenames}")
|
logging.info(f"averaging {filenames}")
|
||||||
model.to(device)
|
model.load_state_dict(average_checkpoints(filenames))
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
|
||||||
|
if params.export:
|
||||||
|
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
|
||||||
|
torch.save(
|
||||||
|
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user