add pretrained model to HF

Desh Raj 2022-05-13 10:00:43 -04:00
parent 2cf0d51e89
commit 2381ba544d
2 changed files with 63 additions and 21 deletions

View File

@@ -40,17 +40,43 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
     --use-fp16 True
 ```
-The tensorboard training log can be found at
-<https://tensorboard.dev/experiment/ExSoBmrPRx6liMTGLu0Tgw/#scalars>
 The decoding command is:
 ```
-## fast beam search
-./pruned_transducer_stateless/decode.py \
-    --avg-last-n 10 \
-    --exp-dir pruned_transducer_stateless/exp \
-    --max-duration 500 \
-    --beam-size 4 \
-    --max-contexts 4 \
-    --max-states 8
+# greedy search
+./pruned_transducer_stateless2/decode.py \
+    --avg-last-n 10 \
+    --exp-dir ./pruned_transducer_stateless2/exp \
+    --max-duration 100 \
+    --decoding-method greedy_search
+
+# beam search
+./pruned_transducer_stateless2/decode.py \
+    --avg-last-n 10 \
+    --exp-dir ./pruned_transducer_stateless2/exp \
+    --max-duration 100 \
+    --decoding-method beam_search \
+    --beam-size 4
+
+# modified beam search
+./pruned_transducer_stateless2/decode.py \
+    --avg-last-n 10 \
+    --exp-dir ./pruned_transducer_stateless2/exp \
+    --max-duration 100 \
+    --decoding-method modified_beam_search \
+    --beam-size 4
+
+# fast beam search
+./pruned_transducer_stateless2/decode.py \
+    --avg-last-n 10 \
+    --exp-dir ./pruned_transducer_stateless2/exp \
+    --max-duration 1500 \
+    --decoding-method fast_beam_search \
+    --beam 4 \
+    --max-contexts 4 \
+    --max-states 8
 ```
+
+Pretrained model is available at <https://huggingface.co/desh2608/icefall-asr-spgispeech-pruned-transducer-stateless2>
+
+The tensorboard training log can be found at
+<https://tensorboard.dev/experiment/ExSoBmrPRx6liMTGLu0Tgw/#scalars>
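
To fetch the pretrained model programmatically rather than cloning the repository with git-lfs, something like the following sketch should work. The `huggingface_hub` package is an assumption here (it is not part of this recipe); a plain git-lfs clone of the repo achieves the same thing.

```
# Hedged sketch: download the pretrained model from the Hugging Face Hub.
# Assumes the huggingface_hub package is installed (pip install huggingface_hub).
from huggingface_hub import snapshot_download

# Downloads the repo contents (checkpoint, BPE model, etc.) into the local
# cache and returns the directory they were placed in.
local_dir = snapshot_download(
    repo_id="desh2608/icefall-asr-spgispeech-pruned-transducer-stateless2"
)
print(f"Model files downloaded to {local_dir}")
```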

View File

@@ -23,8 +23,7 @@ Usage:
 ./pruned_transducer_stateless2/export.py \
   --exp-dir ./pruned_transducer_stateless2/exp \
   --bpe-model data/lang_bpe_500/bpe.model \
-  --epoch 20 \
-  --avg 10
+  --avg-last-n 10
 
 It will generate a file exp_dir/pretrained.pt
@@ -34,7 +33,7 @@ you can do:
 cd /path/to/exp_dir
 ln -s pretrained.pt epoch-9999.pt
 
-cd /path/to/egs/librispeech/ASR
+cd /path/to/egs/spgispeech/ASR
 ./pruned_transducer_stateless2/decode.py \
     --exp-dir ./pruned_transducer_stateless2/exp \
     --epoch 9999 \
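
After running export.py, a quick sanity check on the generated file can catch a bad export before the symlink trick above. A minimal sketch, assuming the checkpoint stores its weights under a "model" key, which is how icefall export scripts typically save it (this diff does not show the save call):

```
# Hedged sketch: inspect the exported checkpoint.
# The "model" key is an assumption about export.py's save format.
import torch

ckpt = torch.load(
    "pruned_transducer_stateless2/exp/pretrained.pt", map_location="cpu"
)
assert "model" in ckpt, f"unexpected keys: {list(ckpt.keys())}"
print(f"{len(ckpt['model'])} tensors in the exported state dict")
```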
@@ -51,7 +50,11 @@ import sentencepiece as spm
 import torch
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.utils import str2bool
@@ -77,6 +80,17 @@ def get_parser():
         "'--epoch'. ",
     )
 
+    parser.add_argument(
+        "--avg-last-n",
+        type=int,
+        default=0,
+        help="""If positive, --epoch and --avg are ignored and it
+        will use the last n checkpoints exp_dir/checkpoint-xxx.pt
+        where xxx is the number of processed batches while
+        saving that checkpoint.
+        """,
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
@@ -105,8 +119,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -141,7 +154,12 @@ def main():
 
     model.to(device)
 
-    if params.avg == 1:
+    if params.avg_last_n > 0:
+        filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))
+    elif params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     else:
         start = params.epoch - params.avg + 1
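
For readers unfamiliar with these helpers, here is a self-contained sketch of what the new `--avg-last-n` branch amounts to. The `checkpoint-<batches>.pt` naming comes from the help text added above; the exact behavior of `find_checkpoints` and `average_checkpoints` is inferred from their names and is an assumption, not a copy of the icefall implementation.

```
# Hedged sketch of the --avg-last-n code path: pick the n most recent
# batch checkpoints and average their "model" state dicts element-wise.
import re
from pathlib import Path

import torch


def last_n_checkpoints(exp_dir, n):
    """Return the n most recent exp_dir/checkpoint-xxx.pt files, newest first."""
    pattern = re.compile(r"checkpoint-(\d+)\.pt$")
    ckpts = []
    for path in Path(exp_dir).glob("checkpoint-*.pt"):
        m = pattern.match(path.name)
        if m is not None:
            ckpts.append((int(m.group(1)), path))
    ckpts.sort(key=lambda t: t[0], reverse=True)  # most processed batches first
    return [str(path) for _, path in ckpts[:n]]


def average_model_states(filenames):
    """Element-wise average of the "model" state dicts in the given files."""
    avg = None
    for f in filenames:
        state = torch.load(f, map_location="cpu")["model"]
        if avg is None:
            avg = {k: v.detach().clone().float() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].float()
    for k in avg:
        avg[k] /= len(filenames)
    return avg
```

The result plays the role of `average_checkpoints(filenames, device=device)` above: a state dict that can be passed straight to `model.load_state_dict`.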
@@ -174,9 +192,7 @@
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     main()