mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
add pretrained model to HF
This commit is contained in:
parent
2cf0d51e89
commit
2381ba544d
@ -40,17 +40,43 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
||||
--use-fp16 True
|
||||
```
|
||||
|
||||
The tensorboard training log can be found at
|
||||
<https://tensorboard.dev/experiment/ExSoBmrPRx6liMTGLu0Tgw/#scalars>
|
||||
|
||||
The decoding command is:
|
||||
```
|
||||
## fast beam search
|
||||
./pruned_transducer_stateless/decode.py \
|
||||
--avg-last-n 10 \
|
||||
--exp-dir pruned_transducer_stateless/exp \
|
||||
--max-duration 500 \
|
||||
--beam-size 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8
|
||||
# greedy search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--avg-last-n 10 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
# beam search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--avg-last-n 10 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
# modified beam search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--avg-last-n 10 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
# fast beam search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--avg-last-n 10 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 1500 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8
|
||||
```
|
||||
|
||||
The pretrained model is available at <https://huggingface.co/desh2608/icefall-asr-spgispeech-pruned-transducer-stateless2>.
|
||||
|
||||
The tensorboard training log can be found at
|
||||
<https://tensorboard.dev/experiment/ExSoBmrPRx6liMTGLu0Tgw/#scalars>
|
||||
|
@ -23,8 +23,7 @@ Usage:
|
||||
./pruned_transducer_stateless2/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
--avg-last-n 10
|
||||
|
||||
It will generate a file exp_dir/pretrained.pt
|
||||
|
||||
@ -34,7 +33,7 @@ you can do:
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/librispeech/ASR
|
||||
cd /path/to/egs/spgispeech/ASR
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--epoch 9999 \
|
||||
@ -51,7 +50,11 @@ import sentencepiece as spm
|
||||
import torch
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
@ -77,6 +80,17 @@ def get_parser():
|
||||
"'--epoch'. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg-last-n",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch and --avg are ignored and it
|
||||
will use the last n checkpoints exp_dir/checkpoint-xxx.pt
|
||||
where xxx is the number of processed batches while
|
||||
saving that checkpoint.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
@ -105,8 +119,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
return parser
|
||||
@ -141,7 +154,12 @@ def main():
|
||||
|
||||
model.to(device)
|
||||
|
||||
if params.avg == 1:
|
||||
if params.avg_last_n > 0:
|
||||
filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
@ -174,9 +192,7 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user