mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
modify pretrained.py
This commit is contained in:
parent
8f3645e5cb
commit
1138b27f16
@ -358,9 +358,9 @@ def main():
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
|
||||
if params.jit_trace is True:
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
logging.info("Using torch.jit.trace()")
|
||||
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
|
||||
export_encoder_model_jit_trace(model.encoder, encoder_filename)
|
||||
|
@ -18,16 +18,16 @@
|
||||
Usage:
|
||||
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless2/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
|
||||
./lstm_transducer_stateless/pretrained.py \
|
||||
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method greedy_search \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(2) beam search
|
||||
./pruned_transducer_stateless2/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
|
||||
./lstm_transducer_stateless/pretrained.py \
|
||||
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method beam_search \
|
||||
--beam-size 4 \
|
||||
@ -35,8 +35,8 @@ Usage:
|
||||
/path/to/bar.wav
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless2/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
|
||||
./lstm_transducer_stateless/pretrained.py \
|
||||
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method modified_beam_search \
|
||||
--beam-size 4 \
|
||||
@ -44,18 +44,18 @@ Usage:
|
||||
/path/to/bar.wav
|
||||
|
||||
(4) fast beam search
|
||||
./pruned_transducer_stateless2/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
|
||||
./lstm_transducer_stateless/pretrained.py \
|
||||
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method fast_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
You can also use `./pruned_transducer_stateless2/exp/epoch-xx.pt`.
|
||||
You can also use `./lstm_transducer_stateless/exp/epoch-xx.pt`.
|
||||
|
||||
Note: ./pruned_transducer_stateless2/exp/pretrained.pt is generated by
|
||||
./pruned_transducer_stateless2/export.py
|
||||
Note: ./lstm_transducer_stateless/exp/pretrained.pt is generated by
|
||||
./lstm_transducer_stateless/export.py
|
||||
"""
|
||||
|
||||
|
||||
@ -77,7 +77,7 @@ from beam_search import (
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import get_params, get_transducer_model
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
|
||||
def get_parser():
|
||||
@ -178,6 +178,8 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -268,7 +270,7 @@ def main():
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
encoder_out, encoder_out_lens, _ = model.encoder(
|
||||
x=features, x_lens=feature_lengths
|
||||
)
|
||||
|
||||
|
@ -111,9 +111,10 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--aux-layer-period",
|
||||
type=int,
|
||||
default=3,
|
||||
default=0,
|
||||
help="""Peroid of auxiliary layers used for randomly combined during training.
|
||||
If not larger than 0 (e.g., -1), will not use the random combiner.
|
||||
If set to 0, will not use the random combiner (Default).
|
||||
You can set a positive integer to use the random combiner, e.g., 3.
|
||||
""",
|
||||
)
|
||||
|
||||
@ -206,7 +207,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--lr-epochs",
|
||||
type=float,
|
||||
default=6,
|
||||
default=10,
|
||||
help="""Number of epochs that affects how rapidly the learning rate decreases.
|
||||
""",
|
||||
)
|
||||
@ -270,7 +271,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--save-every-n",
|
||||
type=int,
|
||||
default=8000,
|
||||
default=4000,
|
||||
help="""Save checkpoint after processing this number of batches"
|
||||
periodically. We save checkpoint to exp-dir/ whenever
|
||||
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||
|
Loading…
x
Reference in New Issue
Block a user