diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
index 345792a3c..3630cebeb 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
@@ -46,7 +46,7 @@ import logging
 from pathlib import Path
 
 import torch
-from train import get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
@@ -78,7 +78,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless2/exp",
+        default="pruned_transducer_stateless5/exp",
         help="""It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
         """,
@@ -106,6 +106,7 @@ def get_parser():
         help="The context size in the decoder. 1 means bigram; "
         "2 means tri-gram",
     )
+    add_model_arguments(parser)
 
     return parser
 
@@ -134,7 +135,6 @@ def main():
     model = get_transducer_model(params)
 
     model.to(device)
-
     if params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     else:
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
index 03bd45d20..5bfb20bdf 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
@@ -227,7 +227,6 @@ def greedy_search(
     for t in range(T):
         # current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
         current_encoder_out = encoder_out[:, t : t + 1, :]  # noqa
-        # print("encoder_out shape: ", current_encoder_out.shape, "decoder_out shape: ", decoder_out.shape)
         logits = model.joiner(
             current_encoder_out.unsqueeze(2),
             decoder_out.unsqueeze(1),
@@ -278,6 +277,7 @@ def fast_beam_search(
         contexts = contexts.to(torch.int64)
         # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
         decoder_out = model.decoder(contexts, need_pad=False)
+        decoder_out = model.joiner.decoder_proj(decoder_out)
         # current_encoder_out is of shape
         # (shape.NumElements(), 1, joiner_dim)
         # fmt: off
@@ -288,6 +288,7 @@ def fast_beam_search(
         logits = model.joiner(
             current_encoder_out.unsqueeze(2),
             decoder_out.unsqueeze(1),
+            project_input=False,
         )
         logits = logits.squeeze(1).squeeze(1)
         log_probs = logits.log_softmax(dim=-1)