modify export.py
parent 1a9af61497
commit 626a2912ca
@@ -20,28 +20,42 @@
 # to a single one using model averaging.
 """
 Usage:
-./pruned_transducer_stateless4/export.py \
-  --exp-dir ./pruned_transducer_stateless4/exp \
+./conv_emformer_transducer_stateless/export.py \
+  --exp-dir ./conv_emformer_transducer_stateless/exp \
   --bpe-model data/lang_bpe_500/bpe.model \
-  --epoch 20 \
-  --avg 10
+  --epoch 30 \
+  --avg 10 \
+  --use-averaged-model=True \
+  --num-encoder-layers 12 \
+  --chunk-length 32 \
+  --cnn-module-kernel 31 \
+  --left-context-length 32 \
+  --right-context-length 8 \
+  --memory-size 32 \
+  --jit False
 
 It will generate a file exp_dir/pretrained.pt
 
-To use the generated file with `pruned_transducer_stateless4/decode.py`,
+To use the generated file with `conv_emformer_transducer_stateless/decode.py`,
 you can do:
 
     cd /path/to/exp_dir
     ln -s pretrained.pt epoch-9999.pt
 
     cd /path/to/egs/librispeech/ASR
-    ./pruned_transducer_stateless4/decode.py \
-        --exp-dir ./pruned_transducer_stateless4/exp \
+    ./conv_emformer_transducer_stateless/decode.py \
+        --exp-dir ./conv_emformer_transducer_stateless/exp \
         --epoch 9999 \
        --avg 1 \
         --max-duration 100 \
         --bpe-model data/lang_bpe_500/bpe.model \
-        --use-averaged-model False
+        --use-averaged-model=False \
+        --num-encoder-layers 12 \
+        --chunk-length 32 \
+        --cnn-module-kernel 31 \
+        --left-context-length 32 \
+        --right-context-length 8 \
+        --memory-size 32
 """
 
 import argparse
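
Note on the usage block above: the exported pretrained.pt is an ordinary checkpoint file, which is why the `ln -s pretrained.pt epoch-9999.pt` trick lets decode.py pick it up as if it were a training checkpoint. A minimal sketch for inspecting the file directly, assuming (not confirmed by this diff) that the weights sit under a "model" key:

    import torch

    # Sketch only: open the exported checkpoint and look at its contents.
    # The "model" key is an assumption about how export.py saves the state dict.
    checkpoint = torch.load(
        "conv_emformer_transducer_stateless/exp/pretrained.pt", map_location="cpu"
    )
    print(sorted(checkpoint.keys()))
    print(len(checkpoint["model"]))  # number of tensors in the saved state dict
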
@@ -50,7 +64,7 @@ from pathlib import Path
 
 import sentencepiece as spm
 import torch
-from train import get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import (
     average_checkpoints,
@@ -137,6 +151,8 @@ def get_parser():
        "`epoch` are loaded for averaging. ",
    )
 
+    add_model_arguments(parser)
+
    return parser
 
 
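
For readers who have not opened train.py: `add_model_arguments(parser)` is what exposes the new encoder flags from the usage string (--num-encoder-layers, --chunk-length, and so on) to export.py. A rough sketch of such a helper, with flag names taken from the usage block and defaults assumed from the example values (the authoritative definition lives in the recipe's train.py):

    import argparse

    def add_model_arguments(parser: argparse.ArgumentParser) -> None:
        # Sketch only: defaults are assumptions mirroring the usage example above.
        parser.add_argument("--num-encoder-layers", type=int, default=12)
        parser.add_argument("--chunk-length", type=int, default=32)
        parser.add_argument("--cnn-module-kernel", type=int, default=31)
        parser.add_argument("--left-context-length", type=int, default=32)
        parser.add_argument("--right-context-length", type=int, default=8)
        parser.add_argument("--memory-size", type=int, default=32)
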
@@ -163,8 +179,6 @@ def main():
     logging.info("About to create model")
     model = get_transducer_model(params)
 
-    model.to(device)
-
     if not params.use_averaged_model:
         if params.iter > 0:
             filenames = find_checkpoints(
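
The early `model.to(device)` is dropped here; the checkpoint-loading branches that follow handle placement, and an export script like this one typically writes pretrained.pt from a CPU state dict so the file loads on any machine. A hedged sketch of that save step, not the recipe's exact code (the "model" key is an assumption):

    import torch

    def save_pretrained(model: torch.nn.Module, out_path: str) -> None:
        # Sketch only: serialize from CPU; the "model" key mirrors how decode.py
        # is assumed to read the checkpoint back.
        model.to("cpu")
        model.eval()
        torch.save({"model": model.state_dict()}, out_path)
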
@@ -281,7 +281,7 @@ def modified_beam_search(
 
     blank_id = model.decoder.blank_id
     context_size = model.decoder.context_size
-    device = model.device
+    device = next(model.parameters()).device
     batch_size = len(streams)
     T = encoder_out.size(1)
 
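
This change, together with the removal of `model.device = device` in the next hunk, replaces a hand-maintained attribute with a lookup on the model's own parameters, which always reflects where the weights actually live. A tiny self-contained illustration:

    import torch

    net = torch.nn.Linear(4, 4)
    print(next(net.parameters()).device)      # cpu

    if torch.cuda.is_available():
        net.to("cuda")
        print(next(net.parameters()).device)  # cuda:0
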
@@ -874,7 +874,6 @@ def main():
     )
 
     model.eval()
-    model.device = device
 
     if params.decoding_method == "fast_beam_search":
         decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)