Change code for fast_beam_search and CPU jit export

luomingshuang 2022-07-22 10:26:19 +08:00
parent 6d77f4c239
commit bd043b0ff0
2 changed files with 5 additions and 4 deletions

(File 1 of 2)

@@ -46,7 +46,7 @@ import logging
 from pathlib import Path
 import torch
-from train import get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_transducer_model
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
@@ -78,7 +78,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless2/exp",
+        default="pruned_transducer_stateless5/exp",
         help="""It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
         """,
@@ -106,6 +106,7 @@ def get_parser():
         help="The context size in the decoder. 1 means bigram; "
         "2 means tri-gram",
     )
+    add_model_arguments(parser)
     return parser
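
Note: add_model_arguments(parser) lets the export script share the model-structure flags defined in train.py, so the exported model is rebuilt with the same configuration used during training. A minimal sketch of what such a helper registers (the flag names and defaults below are assumptions modeled on pruned_transducer_stateless5/train.py, not copied from this commit):

import argparse

def add_model_arguments(parser: argparse.ArgumentParser) -> None:
    # Hypothetical model-structure flags; the real helper lives in
    # train.py and may register a different set of options.
    parser.add_argument("--num-encoder-layers", type=int, default=12)
    parser.add_argument("--dim-feedforward", type=int, default=2048)
    parser.add_argument("--nhead", type=int, default=8)
    parser.add_argument("--encoder-dim", type=int, default=512)
    parser.add_argument("--decoder-dim", type=int, default=512)
    parser.add_argument("--joiner-dim", type=int, default=512)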
@@ -134,7 +135,6 @@ def main():
     model = get_transducer_model(params)
-    model.to(device)
     if params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     else:
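
If the deleted line here is the model.to(device) call (the captured diff drops its +/- markers, but the hunk line counts and the commit title point that way), the intent is to keep the weights on CPU so the model can be scripted directly. A hedged sketch of such an export step (cpu_jit.pt is an assumed output name, not the literal code of this file):

model.to("cpu")
model.eval()
# torch.jit.script produces a scripted model; keeping the weights on
# CPU makes the saved file loadable on machines without a GPU.
scripted = torch.jit.script(model)
scripted.save(f"{params.exp_dir}/cpu_jit.pt")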

(File 2 of 2)

@@ -227,7 +227,6 @@ def greedy_search(
     for t in range(T):
         # current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
         current_encoder_out = encoder_out[:, t : t + 1, :]  # noqa
-        # print("encoder_out shape: ", current_encoder_out.shape, "decoder_out shape: ", decoder_out.shape)
         logits = model.joiner(
             current_encoder_out.unsqueeze(2),
             decoder_out.unsqueeze(1),
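
The unsqueeze calls give both joiner inputs a common 4-D layout whose middle axes broadcast against each other. A toy shape check (dimensions are made up):

import torch

B, C = 4, 512  # hypothetical batch size and embedding dim
current_encoder_out = torch.randn(B, 1, C)  # one frame per stream
decoder_out = torch.randn(B, 1, C)          # one decoder step per stream
# The frame axis and the decoder-step axis broadcast inside the joiner:
assert current_encoder_out.unsqueeze(2).shape == (B, 1, 1, C)
assert decoder_out.unsqueeze(1).shape == (B, 1, 1, C)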
@@ -278,6 +277,7 @@ def fast_beam_search(
         contexts = contexts.to(torch.int64)
         # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
         decoder_out = model.decoder(contexts, need_pad=False)
+        decoder_out = model.joiner.decoder_proj(decoder_out)
         # current_encoder_out is of shape
         # (shape.NumElements(), 1, joiner_dim)
         # fmt: off
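
The new decoder_proj line assumes a joiner of the pruned_transducer_stateless5 style, which owns the projections of its inputs. A hedged sketch of such a joiner (simplified; the real class in icefall may differ in details such as the nonlinearity):

import torch
import torch.nn as nn

class Joiner(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, joiner_dim, vocab_size):
        super().__init__()
        self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
        self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
        self.output_linear = nn.Linear(joiner_dim, vocab_size)

    def forward(self, encoder_out, decoder_out, project_input=True):
        # When the caller has already applied encoder_proj/decoder_proj
        # (as fast_beam_search now does for decoder_out), project_input
        # must be False so the inputs are not projected twice.
        if project_input:
            encoder_out = self.encoder_proj(encoder_out)
            decoder_out = self.decoder_proj(decoder_out)
        return self.output_linear(torch.tanh(encoder_out + decoder_out))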
@@ -288,6 +288,7 @@ def fast_beam_search(
         logits = model.joiner(
             current_encoder_out.unsqueeze(2),
             decoder_out.unsqueeze(1),
+            project_input=False,
         )
         logits = logits.squeeze(1).squeeze(1)
         log_probs = logits.log_softmax(dim=-1)
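
project_input=False is the counterpart of the decoder_proj call above: the comment in the previous hunk already notes that current_encoder_out is in joiner_dim (i.e., pre-projected), and decoder_out is now pre-projected too, so the joiner must skip its internal projections. A toy consistency check, reusing the hedged Joiner sketch above with made-up dimensions:

joiner = Joiner(encoder_dim=512, decoder_dim=512, joiner_dim=512, vocab_size=500)
enc = torch.randn(2, 1, 1, 512)
dec = torch.randn(2, 1, 1, 512)
# Projecting outside the joiner with project_input=False must match
# letting the joiner project internally; projecting twice would
# corrupt the logits.
outside = joiner(joiner.encoder_proj(enc), joiner.decoder_proj(dec),
                 project_input=False)
inside = joiner(enc, dec, project_input=True)
assert torch.allclose(outside, inside, atol=1e-6)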