modify pretrained.py

yaozengwei 2022-08-10 17:16:09 +08:00
parent 8f3645e5cb
commit 1138b27f16
3 changed files with 21 additions and 18 deletions

lstm_transducer_stateless/export.py

@@ -358,9 +358,9 @@ def main():
     model.to("cpu")
     model.eval()
+    convert_scaled_to_non_scaled(model, inplace=True)
     if params.jit_trace is True:
-        convert_scaled_to_non_scaled(model, inplace=True)
         logging.info("Using torch.jit.trace()")
         encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
         export_encoder_model_jit_trace(model.encoder, encoder_filename)
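The hunk above hoists `convert_scaled_to_non_scaled` out of the `jit_trace` branch, so the scaled layers are converted no matter which export path runs afterwards. A minimal sketch of the resulting flow in export.py; the `else` branch is an assumed non-traced fallback for illustration and is not taken from the diff:

```python
import logging
import torch

# Sketch only: `params`, `model` and the export helpers are the objects used
# in export.py and are not redefined here.
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True)  # now runs for every export mode

if params.jit_trace is True:
    logging.info("Using torch.jit.trace()")
    encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
    export_encoder_model_jit_trace(model.encoder, encoder_filename)
else:
    # Hypothetical non-traced path: save plain weights for pretrained.py.
    torch.save({"model": model.state_dict()}, params.exp_dir / "pretrained.pt")
```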

lstm_transducer_stateless/pretrained.py

@@ -18,16 +18,16 @@
 Usage:
 (1) greedy search
-./pruned_transducer_stateless2/pretrained.py \
-  --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
+./lstm_transducer_stateless/pretrained.py \
+  --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
   --bpe-model ./data/lang_bpe_500/bpe.model \
   --method greedy_search \
   /path/to/foo.wav \
   /path/to/bar.wav
 (2) beam search
-./pruned_transducer_stateless2/pretrained.py \
-  --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
+./lstm_transducer_stateless/pretrained.py \
+  --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
   --bpe-model ./data/lang_bpe_500/bpe.model \
   --method beam_search \
   --beam-size 4 \
@@ -35,8 +35,8 @@ Usage:
   /path/to/bar.wav
 (3) modified beam search
-./pruned_transducer_stateless2/pretrained.py \
-  --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
+./lstm_transducer_stateless/pretrained.py \
+  --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
   --bpe-model ./data/lang_bpe_500/bpe.model \
   --method modified_beam_search \
   --beam-size 4 \
@@ -44,18 +44,18 @@ Usage:
   /path/to/bar.wav
 (4) fast beam search
-./pruned_transducer_stateless2/pretrained.py \
-  --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
+./lstm_transducer_stateless/pretrained.py \
+  --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
   --bpe-model ./data/lang_bpe_500/bpe.model \
   --method fast_beam_search \
   --beam-size 4 \
   /path/to/foo.wav \
   /path/to/bar.wav
-You can also use `./pruned_transducer_stateless2/exp/epoch-xx.pt`.
+You can also use `./lstm_transducer_stateless/exp/epoch-xx.pt`.
-Note: ./pruned_transducer_stateless2/exp/pretrained.pt is generated by
-./pruned_transducer_stateless2/export.py
+Note: ./lstm_transducer_stateless/exp/pretrained.pt is generated by
+./lstm_transducer_stateless/export.py
 """
@@ -77,7 +77,7 @@ from beam_search import (
     modified_beam_search,
 )
 from torch.nn.utils.rnn import pad_sequence
-from train import get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_transducer_model
 def get_parser():
@@ -178,6 +178,8 @@ def get_parser():
         """,
     )
+    add_model_arguments(parser)
     return parser
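Together, these two pretrained.py hunks route the model-structure options through the shared `add_model_arguments` helper from train.py, so the decoding script accepts the same flags as training. A runnable miniature of that pattern; only `--aux-layer-period` and its default of 0 come from this diff, everything else is illustrative:

```python
import argparse

def add_model_arguments(parser: argparse.ArgumentParser) -> None:
    # In train.py this registers every model-structure flag; only the one
    # visible in this commit is reproduced here.
    parser.add_argument("--aux-layer-period", type=int, default=0)

def get_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # ... decoding-specific arguments (--checkpoint, --method, ...) ...
    add_model_arguments(parser)  # the call added by this commit
    return parser

print(get_parser().parse_args([]).aux_layer_period)  # -> 0
```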
@@ -268,7 +270,7 @@ def main():
     feature_lengths = torch.tensor(feature_lengths, device=device)
-    encoder_out, encoder_out_lens = model.encoder(
+    encoder_out, encoder_out_lens, _ = model.encoder(
         x=features, x_lens=feature_lengths
     )
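The new call site unpacks three values from the encoder. For the LSTM encoder the third value is presumably the recurrent states kept for streaming decoding; offline decoding in pretrained.py simply discards it. Assumed shape of the call, taken from the hunk plus that assumption:

```python
# encoder_out:      (N, T', C) frame-level embeddings
# encoder_out_lens: (N,) valid output lengths
# _:                assumed to be the LSTM states, unused in offline decoding
encoder_out, encoder_out_lens, _ = model.encoder(
    x=features, x_lens=feature_lengths
)
```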

lstm_transducer_stateless/train.py

@@ -111,9 +111,10 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--aux-layer-period",
         type=int,
-        default=3,
+        default=0,
         help="""Period of auxiliary layers used for the random combiner during training.
-        If not larger than 0 (e.g., -1), will not use the random combiner.
+        If set to 0, will not use the random combiner (default).
+        You can set a positive integer to use the random combiner, e.g., 3.
         """,
     )
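The new default of 0 disables the random combiner unless the user opts in with a positive period. A hedged sketch of how such a period is typically turned into a set of auxiliary layer indices; the real selection logic lives in the encoder code and its exact indexing may differ:

```python
from typing import List

def aux_layer_ids(num_layers: int, aux_layer_period: int) -> List[int]:
    # Assumed behaviour: period <= 0 means "no random combiner";
    # a positive period picks every period-th layer as an auxiliary output.
    if aux_layer_period <= 0:
        return []
    return list(range(aux_layer_period - 1, num_layers - 1, aux_layer_period))

print(aux_layer_ids(12, 0))  # [] -> combiner disabled (new default)
print(aux_layer_ids(12, 3))  # [2, 5, 8] -> e.g. the suggested period of 3
```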
@@ -206,7 +207,7 @@ def get_parser():
     parser.add_argument(
         "--lr-epochs",
         type=float,
-        default=6,
+        default=10,
         help="""Number of epochs that affects how rapidly the learning rate decreases.
         """,
     )
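`--lr-epochs` feeds the epoch term of the scheduler's decay; raising it from 6 to 10 makes the learning rate fall off more slowly across epochs. A hedged sketch of the epoch-dependent factor, with the Eden-style formula reproduced from memory of icefall's optim.py (check that file for the authoritative form):

```python
def epoch_lr_factor(epoch: float, lr_epochs: float) -> float:
    # Assumed Eden-style decay term; the scheduler also has an analogous
    # batch-dependent term not shown here.
    return ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25

print(round(epoch_lr_factor(10, 6), 3))   # ~0.717 with the old default
print(round(epoch_lr_factor(10, 10), 3))  # ~0.841 with the new default
```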
@@ -270,7 +271,7 @@ def get_parser():
     parser.add_argument(
         "--save-every-n",
         type=int,
-        default=8000,
+        default=4000,
         help="""Save checkpoint after processing this number of batches
         periodically. We save checkpoint to exp-dir/ whenever
         params.batch_idx_train % save_every_n == 0. The checkpoint filename
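Halving `--save-every-n` from 8000 to 4000 doubles how often intermediate checkpoints are written. The help text gives the exact condition; a tiny runnable sketch of it:

```python
save_every_n = 4000  # new default, was 8000

def should_save(batch_idx_train: int) -> bool:
    # Per the help text: save whenever batch_idx_train % save_every_n == 0.
    return batch_idx_train % save_every_n == 0

print([i for i in range(1000, 20001, 1000) if should_save(i)])
# -> [4000, 8000, 12000, 16000, 20000]
```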