modify export.py

This commit is contained in:
yaozengwei 2022-06-11 22:03:19 +08:00
parent 1a9af61497
commit 626a2912ca
2 changed files with 26 additions and 13 deletions

View File

@ -20,28 +20,42 @@
# to a single one using model averaging.
"""
Usage:
./pruned_transducer_stateless4/export.py \
--exp-dir ./pruned_transducer_stateless4/exp \
./conv_emformer_transducer_stateless/export.py \
--exp-dir ./conv_emformer_transducer_stateless/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10
--epoch 30 \
--avg 10 \
--use-averaged-model=True \
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
--left-context-length 32 \
--right-context-length 8 \
--memory-size 32 \
--jit False
It will generate a file exp_dir/pretrained.pt
To use the generated file with `pruned_transducer_stateless4/decode.py`,
To use the generated file with `conv_emformer_transducer_stateless/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./pruned_transducer_stateless4/decode.py \
--exp-dir ./pruned_transducer_stateless4/exp \
./conv_emformer_transducer_stateless/decode.py \
--exp-dir ./conv_emformer_transducer_stateless/exp \
--epoch 9999 \
--avg 1 \
--max-duration 100 \
--bpe-model data/lang_bpe_500/bpe.model \
--use-averaged-model False
--use-averaged-model=False \
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
--left-context-length 32 \
--right-context-length 8 \
--memory-size 32
"""
import argparse
@ -50,7 +64,7 @@ from pathlib import Path
import sentencepiece as spm
import torch
from train import get_params, get_transducer_model
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
@ -137,6 +151,8 @@ def get_parser():
"`epoch` are loaded for averaging. ",
)
add_model_arguments(parser)
return parser
@ -163,8 +179,6 @@ def main():
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(

View File

@ -281,7 +281,7 @@ def modified_beam_search(
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
device = next(model.parameters()).device
batch_size = len(streams)
T = encoder_out.size(1)
@ -874,7 +874,6 @@ def main():
)
model.eval()
model.device = device
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)