modify export.py

This commit is contained in:
yaozengwei 2022-06-11 22:03:19 +08:00
parent 1a9af61497
commit 626a2912ca
2 changed files with 26 additions and 13 deletions

View File

@@ -20,28 +20,42 @@
# to a single one using model averaging. # to a single one using model averaging.
""" """
Usage: Usage:
./pruned_transducer_stateless4/export.py \ ./conv_emformer_transducer_stateless/export.py \
--exp-dir ./pruned_transducer_stateless4/exp \ --exp-dir ./conv_emformer_transducer_stateless/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \ --epoch 30 \
--avg 10 --avg 10 \
--use-averaged-model=True \
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
--left-context-length 32 \
--right-context-length 8 \
--memory-size 32 \
--jit False
It will generate a file exp_dir/pretrained.pt It will generate a file exp_dir/pretrained.pt
To use the generated file with `pruned_transducer_stateless4/decode.py`, To use the generated file with `conv_emformer_transducer_stateless/decode.py`,
you can do: you can do:
cd /path/to/exp_dir cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR cd /path/to/egs/librispeech/ASR
./pruned_transducer_stateless4/decode.py \ ./conv_emformer_transducer_stateless/decode.py \
--exp-dir ./pruned_transducer_stateless4/exp \ --exp-dir ./conv_emformer_transducer_stateless/exp \
--epoch 9999 \ --epoch 9999 \
--avg 1 \ --avg 1 \
--max-duration 100 \ --max-duration 100 \
--bpe-model data/lang_bpe_500/bpe.model \ --bpe-model data/lang_bpe_500/bpe.model \
--use-averaged-model False --use-averaged-model=False \
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
--left-context-length 32 \
--right-context-length 8 \
--memory-size 32
""" """
import argparse import argparse
@@ -50,7 +64,7 @@ from pathlib import Path
import sentencepiece as spm import sentencepiece as spm
import torch import torch
from train import get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import ( from icefall.checkpoint import (
average_checkpoints, average_checkpoints,
@@ -137,6 +151,8 @@ def get_parser():
"`epoch` are loaded for averaging. ", "`epoch` are loaded for averaging. ",
) )
add_model_arguments(parser)
return parser return parser
@@ -163,8 +179,6 @@ def main():
logging.info("About to create model") logging.info("About to create model")
model = get_transducer_model(params) model = get_transducer_model(params)
model.to(device)
if not params.use_averaged_model: if not params.use_averaged_model:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints( filenames = find_checkpoints(

View File

@@ -281,7 +281,7 @@ def modified_beam_search(
blank_id = model.decoder.blank_id blank_id = model.decoder.blank_id
context_size = model.decoder.context_size context_size = model.decoder.context_size
device = model.device device = next(model.parameters()).device
batch_size = len(streams) batch_size = len(streams)
T = encoder_out.size(1) T = encoder_out.size(1)
@@ -874,7 +874,6 @@ def main():
) )
model.eval() model.eval()
model.device = device
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)