mirror of https://github.com/k2-fsa/icefall.git

add pretrained model to HF

commit 2381ba544d (parent 2cf0d51e89)
@@ -40,17 +40,43 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
   --use-fp16 True
 ```
 
-The tensorboard training log can be found at
-<https://tensorboard.dev/experiment/ExSoBmrPRx6liMTGLu0Tgw/#scalars>
-
 The decoding command is:
 ```
-## fast beam search
-./pruned_transducer_stateless/decode.py \
+# greedy search
+./pruned_transducer_stateless2/decode.py \
   --avg-last-n 10 \
-  --exp-dir pruned_transducer_stateless/exp \
-  --max-duration 500 \
-  --beam-size 4 \
-  --max-contexts 4 \
-  --max-states 8
+  --exp-dir ./pruned_transducer_stateless2/exp \
+  --max-duration 100 \
+  --decoding-method greedy_search
+
+# beam search
+./pruned_transducer_stateless2/decode.py \
+  --avg-last-n 10 \
+  --exp-dir ./pruned_transducer_stateless2/exp \
+  --max-duration 100 \
+  --decoding-method beam_search \
+  --beam-size 4
+
+# modified beam search
+./pruned_transducer_stateless2/decode.py \
+  --avg-last-n 10 \
+  --exp-dir ./pruned_transducer_stateless2/exp \
+  --max-duration 100 \
+  --decoding-method modified_beam_search \
+  --beam-size 4
+
+# fast beam search
+./pruned_transducer_stateless2/decode.py \
+  --avg-last-n 10 \
+  --exp-dir ./pruned_transducer_stateless2/exp \
+  --max-duration 1500 \
+  --decoding-method fast_beam_search \
+  --beam 4 \
+  --max-contexts 4 \
+  --max-states 8
 ```
+
+Pretrained model is available at <https://huggingface.co/desh2608/icefall-asr-spgispeech-pruned-transducer-stateless2>
+
+The tensorboard training log can be found at
+<https://tensorboard.dev/experiment/ExSoBmrPRx6liMTGLu0Tgw/#scalars>
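Note: the pretrained model linked above can also be fetched programmatically. A minimal sketch, assuming the standard `huggingface_hub` API; the file layout inside the repo (e.g. an `exp/pretrained.pt` plus the BPE model) is an assumption based on icefall's usual conventions:

```python
# Minimal sketch: download the pretrained model from the Hugging Face Hub.
# snapshot_download is the standard huggingface_hub entry point; the layout
# of the downloaded repo is assumed, not confirmed by this commit.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    repo_id="desh2608/icefall-asr-spgispeech-pruned-transducer-stateless2"
)
print(local_dir)  # local cache directory holding the repo's files
```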
@@ -23,8 +23,7 @@ Usage:
 ./pruned_transducer_stateless2/export.py \
   --exp-dir ./pruned_transducer_stateless2/exp \
   --bpe-model data/lang_bpe_500/bpe.model \
-  --epoch 20 \
-  --avg 10
+  --avg-last-n 10
 
 It will generate a file exp_dir/pretrained.pt
 
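A quick way to inspect the exported file is to load it directly. This sketch assumes icefall's usual convention of storing the averaged weights under a "model" key:

```python
# Sketch: inspect the exported checkpoint. The "model" key is icefall's
# usual convention for the averaged state dict; treat it as an assumption.
import torch

ckpt = torch.load(
    "pruned_transducer_stateless2/exp/pretrained.pt", map_location="cpu"
)
print(list(ckpt.keys()))  # expect at least a "model" entry
```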
@@ -34,7 +33,7 @@ you can do:
 cd /path/to/exp_dir
 ln -s pretrained.pt epoch-9999.pt
 
-cd /path/to/egs/librispeech/ASR
+cd /path/to/egs/spgispeech/ASR
 ./pruned_transducer_stateless2/decode.py \
   --exp-dir ./pruned_transducer_stateless2/exp \
   --epoch 9999 \
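The `ln -s pretrained.pt epoch-9999.pt` trick works because decode.py resolves checkpoints by filename, roughly as in the illustrative sketch below (not the actual icefall code), presumably together with `--avg 1` so only that one file is loaded:

```python
# Illustrative sketch of why the symlink works: decode.py loads
# f"{exp_dir}/epoch-{epoch}.pt", so a link named epoch-9999.pt makes
# pretrained.pt loadable via --epoch 9999.
from pathlib import Path

import torch


def load_epoch_checkpoint(exp_dir: str, epoch: int) -> dict:
    path = Path(exp_dir) / f"epoch-{epoch}.pt"
    return torch.load(path, map_location="cpu")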
@@ -51,7 +50,11 @@ import sentencepiece as spm
 import torch
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.utils import str2bool
 
 
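`find_checkpoints` is the newly used import here. A plausible sketch of its behavior, for orientation only (the real implementation lives in `icefall/checkpoint.py` and may differ in detail): it collects `exp_dir/checkpoint-xxx.pt` files and returns them newest first, so slicing with `[:n]` picks the last n checkpoints.

```python
# Plausible sketch of find_checkpoints; the real icefall version may differ.
import glob
import re
from typing import List


def find_checkpoints_sketch(exp_dir: str) -> List[str]:
    """Return exp_dir/checkpoint-xxx.pt paths, largest xxx (newest) first."""
    pattern = re.compile(r"checkpoint-(\d+)\.pt$")

    def batch_idx(path: str) -> int:
        m = pattern.search(path)
        return int(m.group(1)) if m else -1

    files = glob.glob(f"{exp_dir}/checkpoint-*.pt")
    return sorted(files, key=batch_idx, reverse=True)
```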
@@ -77,6 +80,17 @@ def get_parser():
         "'--epoch'. ",
     )
 
+    parser.add_argument(
+        "--avg-last-n",
+        type=int,
+        default=0,
+        help="""If positive, --epoch and --avg are ignored and it
+        will use the last n checkpoints exp_dir/checkpoint-xxx.pt
+        where xxx is the number of processed batches while
+        saving that checkpoint.
+        """,
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
@@ -105,8 +119,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -141,7 +154,12 @@ def main():
 
     model.to(device)
 
-    if params.avg == 1:
+    if params.avg_last_n > 0:
+        filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))
+    elif params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     else:
         start = params.epoch - params.avg + 1
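`average_checkpoints` performs an element-wise average of the saved state dicts. A simplified sketch, assuming each checkpoint stores its weights under a "model" key; the real icefall version handles devices and non-float tensors more carefully:

```python
# Simplified sketch of element-wise checkpoint averaging; the "model" key
# convention is an assumption based on icefall's usual checkpoint format.
from typing import List

import torch


def average_checkpoints_sketch(filenames: List[str], device: torch.device) -> dict:
    n = len(filenames)
    avg = torch.load(filenames[0], map_location=device)["model"]
    for f in filenames[1:]:
        state = torch.load(f, map_location=device)["model"]
        for k in avg:
            avg[k] += state[k]
    for k in avg:
        if avg[k].is_floating_point():
            avg[k] /= n  # average float tensors only; counters stay summed
    return avg
```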
@@ -174,9 +192,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()