diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/decode_stream.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode_stream.py
index ba5e80555..7df96a9af 100644
--- a/egs/librispeech/ASR/lstm_transducer_stateless/decode_stream.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode_stream.py
@@ -71,9 +71,9 @@ class DecodeStream(object):
         # encoder.streaming_forward
         self.done_frames: int = 0
 
-        self.pad_length = (
-            params.right_context + 2
-        ) * params.subsampling_factor + 3
+        # add 2 here since we will drop the first and last frames after
+        # the convolutional subsampling module
+        self.pad_length = 2 * params.subsampling_factor + 3
 
         if params.decoding_method == "greedy_search":
             self.hyp = [params.blank_id] * params.context_size
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py
index b3e1f04c3..6ba72ca2a 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py
@@ -17,15 +17,13 @@
 
 """
 Usage:
-./pruned_transducer_stateless2/streaming_decode.py \
-        --epoch 28 \
-        --avg 15 \
-        --left-context 32 \
-        --decode-chunk-size 8 \
-        --right-context 0 \
-        --exp-dir ./pruned_transducer_stateless2/exp \
-        --decoding_method greedy_search \
-        --num-decode-streams 1000
+./lstm_transducer_stateless/streaming_decode.py \
+        --epoch 28 \
+        --avg 15 \
+        --exp-dir ./lstm_transducer_stateless/exp \
+        --decoding-method greedy_search \
+        --decode-chunk-size 1 \
+        --num-decode-streams 1000
 """
 
 import argparse
@@ -44,7 +42,7 @@ from decode_stream import DecodeStream
 from kaldifeat import Fbank, FbankOptions
 from lhotse import CutSet
 from torch.nn.utils.rnn import pad_sequence
-from train import add_model_arguments, get_params, get_transducer_model
+from train import get_params, get_transducer_model
 
 from icefall.checkpoint import (
     average_checkpoints,
@@ -155,24 +153,10 @@ def get_parser():
     parser.add_argument(
         "--decode-chunk-size",
         type=int,
-        default=16,
+        default=1,
         help="The chunk size for decoding (in frames after subsampling)",
     )
 
-    parser.add_argument(
-        "--left-context",
-        type=int,
-        default=64,
-        help="left context can be seen during decoding (in frames after subsampling)",
-    )
-
-    parser.add_argument(
-        "--right-context",
-        type=int,
-        default=0,
-        help="right context can be seen during decoding (in frames after subsampling)",
-    )
-
     parser.add_argument(
         "--num-decode-streams",
         type=int,
@@ -180,8 +164,6 @@ def get_parser():
         help="The number of streams that can be decoded parallel.",
     )
 
-    add_model_arguments(parser)
-
     return parser
 
 
@@ -334,7 +316,7 @@ def decode_one_chunk(
     # we plus 2 here because we will cut off one frame on each size of
     # encoder_embed output as they see invalid paddings. so we need extra 2
     # frames.
-    tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
+    tail_length = 7 + 2 * params.subsampling_factor
     if features.size(1) < tail_length:
         feature_lens += tail_length - features.size(1)
         features = torch.cat(
@@ -357,13 +339,10 @@ def decode_one_chunk(
         ]
         processed_lens = torch.tensor(processed_lens, device=device)
 
-    encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
+    encoder_out, encoder_out_lens, states = model.encoder.infer(
         x=features,
         x_lens=feature_lens,
         states=states,
-        left_context=params.left_context,
-        right_context=params.right_context,
-        processed_lens=processed_lens,
     )
 
     encoder_out = model.joiner.encoder_proj(encoder_out)
@@ -442,8 +421,8 @@ def decode_dataset(
     decode_results = []
     # Contain decode streams currently running.
     decode_streams = []
-    initial_states = model.encoder.get_init_state(
-        params.left_context, device=device
-    )
+    initial_states = model.encoder.get_init_states(
+        device=device
+    )
     for num, cut in enumerate(cuts):
         # each utterance has a DecodeStream.
@@ -584,8 +563,6 @@ def main():
 
     # for streaming
    params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
-    params.suffix += f"-left-context-{params.left_context}"
-    params.suffix += f"-right-context-{params.right_context}"
 
     # for fast_beam_search
     if params.decoding_method == "fast_beam_search":
@@ -609,8 +586,6 @@ def main():
     params.blank_id = sp.piece_to_id("<blk>")
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
-    # Decoding in streaming requires causal convolution
-    params.causal_convolution = True
 
     logging.info(params)
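
For reference, a minimal sketch of the padding arithmetic changed above, assuming subsampling_factor = 4 as in the librispeech recipes (the standalone variables below are illustrative only, not part of the patch). Since the LSTM encoder has no right context, both new expressions are simply the old ones evaluated at right_context = 0:

# Illustrative sketch, not part of the patch.
# Assumes subsampling_factor = 4, the usual value in the librispeech recipes.
subsampling_factor = 4

# decode_stream.py: 2 subsampled frames cover the first and last frames
# dropped after the convolutional subsampling module; 3 is the constant
# offset kept from the old expression.
pad_length = 2 * subsampling_factor + 3
assert pad_length == 11

# streaming_decode.py: minimum number of input feature frames per chunk;
# 7 is the constant kept from the old expression, and the extra
# 2 * subsampling_factor input frames yield the 2 output frames that are
# cut off (one on each side of the encoder_embed output).
tail_length = 7 + 2 * subsampling_factor
assert tail_length == 15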