From 626a2912cac839e8430fe51bb9f470cb7dcf3db5 Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Sat, 11 Jun 2022 22:03:19 +0800
Subject: [PATCH] modify export.py

---
 .../export.py           | 36 +++++++++++++------
 .../streaming_decode.py |  3 +-
 2 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py
index 8f64b5d64..4930881ea 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py
@@ -20,28 +20,42 @@
 # to a single one using model averaging.
 """
 Usage:
-./pruned_transducer_stateless4/export.py \
-  --exp-dir ./pruned_transducer_stateless4/exp \
+./conv_emformer_transducer_stateless/export.py \
+  --exp-dir ./conv_emformer_transducer_stateless/exp \
   --bpe-model data/lang_bpe_500/bpe.model \
-  --epoch 20 \
-  --avg 10
+  --epoch 30 \
+  --avg 10 \
+  --use-averaged-model=True \
+  --num-encoder-layers 12 \
+  --chunk-length 32 \
+  --cnn-module-kernel 31 \
+  --left-context-length 32 \
+  --right-context-length 8 \
+  --memory-size 32 \
+  --jit False

 It will generate a file exp_dir/pretrained.pt

-To use the generated file with `pruned_transducer_stateless4/decode.py`,
+To use the generated file with `conv_emformer_transducer_stateless/decode.py`,
 you can do:

     cd /path/to/exp_dir
     ln -s pretrained.pt epoch-9999.pt

     cd /path/to/egs/librispeech/ASR
-    ./pruned_transducer_stateless4/decode.py \
-        --exp-dir ./pruned_transducer_stateless4/exp \
+    ./conv_emformer_transducer_stateless/decode.py \
+        --exp-dir ./conv_emformer_transducer_stateless/exp \
         --epoch 9999 \
         --avg 1 \
         --max-duration 100 \
         --bpe-model data/lang_bpe_500/bpe.model \
-        --use-averaged-model False
+        --use-averaged-model=False \
+        --num-encoder-layers 12 \
+        --chunk-length 32 \
+        --cnn-module-kernel 31 \
+        --left-context-length 32 \
+        --right-context-length 8 \
+        --memory-size 32
 """

 import argparse
@@ -50,7 +64,7 @@ from pathlib import Path

 import sentencepiece as spm
 import torch
-from train import get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_transducer_model

 from icefall.checkpoint import (
     average_checkpoints,
@@ -137,6 +151,8 @@ def get_parser():
         "`epoch` are loaded for averaging. ",
     )

+    add_model_arguments(parser)
+
     return parser


@@ -163,8 +179,6 @@ def main():
     logging.info("About to create model")
     model = get_transducer_model(params)

-    model.to(device)
-
     if not params.use_averaged_model:
         if params.iter > 0:
             filenames = find_checkpoints(
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py
index 62ba144b4..2ff1189fc 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py
@@ -281,7 +281,7 @@ def modified_beam_search(

     blank_id = model.decoder.blank_id
     context_size = model.decoder.context_size
-    device = model.device
+    device = next(model.parameters()).device
     batch_size = len(streams)

     T = encoder_out.size(1)
@@ -874,7 +874,6 @@ def main():
             )

     model.eval()
-    model.device = device

     if params.decoding_method == "fast_beam_search":
         decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
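
Note on the streaming_decode.py hunks (this note is not part of the patch): the decoding
code stops relying on a hand-set `model.device` attribute and instead reads the device
from the model's own parameters, so the device cannot go stale after `model.to(...)` or
checkpoint averaging. A minimal sketch of the idiom, assuming only a generic
torch.nn.Module; the `nn.Linear` stand-in and tensor names below are illustrative, not
taken from the recipe:

    import torch
    import torch.nn as nn

    # Stand-in for the transducer model; any nn.Module with at least one
    # parameter supports the same lookup.
    model = nn.Linear(4, 2)
    model.to("cuda" if torch.cuda.is_available() else "cpu")

    # Read the device from the parameters instead of maintaining a separate
    # `model.device` attribute by hand.
    device = next(model.parameters()).device

    # New tensors created during decoding can then be placed consistently.
    hyps = torch.zeros(1, 10, dtype=torch.int64, device=device)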