diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py
index 050bef60a..e96277ea7 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 from typing import List, Optional, Tuple
 
 import k2
@@ -45,6 +46,7 @@ class DecodeStream(object):
             assert device == decoding_graph.device
 
         self.params = params
+        self.LOG_EPS = math.log(1e-10)
 
         self.states = initial_states
 
@@ -52,6 +54,7 @@
         self.features: torch.Tensor = None
         # how many frames have been processed. (before subsampling).
         # we only modify this value in `func:get_feature_frames`.
+        self.num_frames = 0
         self.num_processed_frames: int = 0
         self._done: bool = False
         # The transcript of current utterance.
@@ -62,7 +65,11 @@
         # how many frames have been processed, after subsampling (i.e. a
         # cumulative sum of the second return value of
         # encoder.streaming_forward
-        self.feature_len: int = 0
+        self.done_frames: int = 0
+
+        self.pad_length = (
+            params.right_context + 2
+        ) * params.subsampling_factor + 3
 
         if params.decoding_method == "greedy_search":
             self.hyp = [params.blank_id] * params.context_size
@@ -86,27 +93,32 @@
         features: torch.Tensor,
     ) -> None:
         """Set features tensor of current utterance."""
-        self.features = features
+        assert features.dim() == 2, features.dim()
+        self.num_frames = features.size(0)
+        # tail padding
+        self.features = torch.nn.functional.pad(
+            features,
+            (0, 0, 0, self.pad_length),
+            mode="constant",
+            value=self.LOG_EPS,
+        )
 
     def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]:
         """Consume chunk_size frames of features"""
         # plus 3 here because we subsampling features with
         # lengths = ((x_lens - 1) // 2 - 1) // 2
-        ret_chunk_size = min(
-            self.features.size(0) - self.num_processed_frames, chunk_size + 3
+        update_length = min(
-            self.num_frames - self.num_processed_frames, chunk_size
         )
+        ret_length = update_length + self.pad_length
+
         ret_features = self.features[
             self.num_processed_frames : self.num_processed_frames  # noqa
-            + ret_chunk_size,
-            :,
+            + ret_length
         ]
-        self.num_processed_frames += (
-            chunk_size
-            - 2 * self.params.subsampling_factor
-            - self.params.right_context * self.params.subsampling_factor
-        )
-        if self.num_processed_frames >= self.features.size(0):
+        self.num_processed_frames += update_length
+        if self.num_processed_frames >= self.num_frames:
             self._done = True
 
-        return ret_features, ret_chunk_size
+        return ret_features, ret_length
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py
index f61d10b20..87a35082a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py
@@ -341,20 +341,16 @@ def decode_one_chunk(
 
     states = []
     rnnt_stream_list = []
-    processed_feature_lens = []
+    processed_lens = []
     for stream in decode_streams:
-        # we plus 2 here because we will cut off one frame on each size of
-        # encoder_embed output as they see invalid paddings. so we need extra 2
-        # frames.
         feat, feat_len = stream.get_feature_frames(
-            (params.decode_chunk_size + 2 + params.right_context)
-            * params.subsampling_factor
+            params.decode_chunk_size * params.subsampling_factor
         )
 
         features.append(feat)
         feature_lens.append(feat_len)
         states.append(stream.states)
-        processed_feature_lens.append(stream.feature_len)
+        processed_lens.append(stream.done_frames)
         if params.decoding_method == "fast_beam_search":
             rnnt_stream_list.append(stream.rnnt_decoding_stream)
 
@@ -388,16 +384,15 @@
         torch.stack([x[1] for x in states], dim=2),
     ]
 
-    processed_feature_lens = torch.tensor(processed_feature_lens, device=device)
+    processed_lens = torch.tensor(processed_lens, device=device)
 
-    # Note: states will be modified in streaming_forward.
     encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
         x=features,
         x_lens=feature_lens,
         states=states,
         left_context=params.left_context,
         right_context=params.right_context,
-        processed_lens=processed_feature_lens,
+        processed_lens=processed_lens,
     )
 
     if params.decoding_method == "greedy_search":
@@ -411,7 +406,7 @@
             max_states=params.max_states,
         )
         decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
-        processed_lens = processed_feature_lens + encoder_out_lens
+        processed_lens = processed_lens + encoder_out_lens
         hyp_tokens = fast_beam_search(
             model, encoder_out, processed_lens, decoding_streams
         )
@@ -423,7 +418,7 @@
     finished_streams = []
     for i in range(len(decode_streams)):
         decode_streams[i].states = [states[0][i], states[1][i]]
-        decode_streams[i].feature_len += encoder_out_lens[i]
+        decode_streams[i].done_frames += encoder_out_lens[i]
         if params.decoding_method == "fast_beam_search":
             decode_streams[i].hyp = hyp_tokens[i]
         if decode_streams[i].done:
@@ -469,7 +464,7 @@ def decode_dataset(
     opts.frame_opts.samp_freq = 16000
     opts.mel_opts.num_bins = 80
 
-    log_interval = 50
+    log_interval = 100
 
     decode_results = []
     # Contain decode streams currently running.
@@ -557,6 +552,9 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        # sort results so we can easily compare the difference between two
+        # recognition results
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py
index 6072a288e..8b5650ad7 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py
@@ -347,20 +347,16 @@ def decode_one_chunk(
 
     states = []
     rnnt_stream_list = []
-    processed_feature_lens = []
+    processed_lens = []
     for stream in decode_streams:
-        # we plus 2 here because we will cut off one frame on each size of
-        # encoder_embed output as they see invalid paddings. so we need extra 2
-        # frames.
         feat, feat_len = stream.get_feature_frames(
-            (params.decode_chunk_size + 2 + params.right_context)
-            * params.subsampling_factor
+            params.decode_chunk_size * params.subsampling_factor
         )
 
         features.append(feat)
         feature_lens.append(feat_len)
         states.append(stream.states)
-        processed_feature_lens.append(stream.feature_len)
+        processed_lens.append(stream.done_frames)
         if params.decoding_method == "fast_beam_search":
             rnnt_stream_list.append(stream.rnnt_decoding_stream)
 
@@ -369,6 +365,9 @@ def decode_one_chunk(
 
     # if T is less than 7 there will be an error in time reduction layer,
     # because we subsample features with ((x_len - 1) // 2 - 1) // 2
+    # we add 2 here because we will cut off one frame on each side of the
+    # encoder_embed output as they see invalid paddings. so we need extra 2
+    # frames.
     tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
     if features.size(1) < tail_length:
         feature_lens += tail_length - features.size(1)
@@ -390,7 +389,7 @@
         torch.stack([x[0] for x in states], dim=2),
         torch.stack([x[1] for x in states], dim=2),
     ]
-    processed_feature_lens = torch.tensor(processed_feature_lens, device=device)
+    processed_lens = torch.tensor(processed_lens, device=device)
 
     encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
         x=features,
@@ -398,7 +397,7 @@
         states=states,
         left_context=params.left_context,
         right_context=params.right_context,
-        processed_lens=processed_feature_lens,
+        processed_lens=processed_lens,
     )
 
     encoder_out = model.joiner.encoder_proj(encoder_out)
@@ -414,7 +413,7 @@
             max_states=params.max_states,
         )
         decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
-        processed_lens = processed_feature_lens + encoder_out_lens
+        processed_lens = processed_lens + encoder_out_lens
         hyp_tokens = fast_beam_search(
             model, encoder_out, processed_lens, decoding_streams
         )
@@ -426,7 +425,7 @@
     finished_streams = []
     for i in range(len(decode_streams)):
         decode_streams[i].states = [states[0][i], states[1][i]]
-        decode_streams[i].feature_len += encoder_out_lens[i]
+        decode_streams[i].done_frames += encoder_out_lens[i]
         if params.decoding_method == "fast_beam_search":
             decode_streams[i].hyp = hyp_tokens[i]
         if decode_streams[i].done:
@@ -561,7 +560,10 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
-        store_transcripts(filename=recog_path, texts=sorted(results))
+        # sort results so we can easily compare the difference between two
+        # recognition results
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py
index 1117acbf4..18776c763 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py
@@ -348,20 +348,16 @@ def decode_one_chunk(
 
     states = []
     rnnt_stream_list = []
-    processed_feature_lens = []
+    processed_lens = []
     for stream in decode_streams:
-        # we plus 2 here because we will cut off one frame on each size of
-        # encoder_embed output as they see invalid paddings. so we need extra 2
-        # frames.
         feat, feat_len = stream.get_feature_frames(
-            (params.decode_chunk_size + 2 + params.right_context)
-            * params.subsampling_factor
+            params.decode_chunk_size * params.subsampling_factor
         )
 
         features.append(feat)
         feature_lens.append(feat_len)
         states.append(stream.states)
-        processed_feature_lens.append(stream.feature_len)
+        processed_lens.append(stream.done_frames)
         if params.decoding_method == "fast_beam_search":
             rnnt_stream_list.append(stream.rnnt_decoding_stream)
 
@@ -394,7 +390,7 @@
         torch.stack([x[0] for x in states], dim=2),
         torch.stack([x[1] for x in states], dim=2),
     ]
-    processed_feature_lens = torch.tensor(processed_feature_lens, device=device)
+    processed_lens = torch.tensor(processed_lens, device=device)
 
     encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
         x=features,
@@ -402,7 +398,7 @@
         states=states,
         left_context=params.left_context,
         right_context=params.right_context,
-        processed_lens=processed_feature_lens,
+        processed_lens=processed_lens,
     )
 
     encoder_out = model.joiner.encoder_proj(encoder_out)
@@ -418,7 +414,7 @@
             max_states=params.max_states,
         )
         decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
-        processed_lens = processed_feature_lens + encoder_out_lens
+        processed_lens = processed_lens + encoder_out_lens
         hyp_tokens = fast_beam_search(
             model, encoder_out, processed_lens, decoding_streams
         )
@@ -430,7 +426,7 @@
     finished_streams = []
     for i in range(len(decode_streams)):
         decode_streams[i].states = [states[0][i], states[1][i]]
-        decode_streams[i].feature_len += encoder_out_lens[i]
+        decode_streams[i].done_frames += encoder_out_lens[i]
         if params.decoding_method == "fast_beam_search":
             decode_streams[i].hyp = hyp_tokens[i]
         if decode_streams[i].done:
@@ -565,7 +561,8 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
-        store_transcripts(filename=recog_path, texts=sorted(results))
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py
index 7d072079e..2ca00ec40 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py
@@ -359,20 +359,16 @@ def decode_one_chunk(
 
     states = []
     rnnt_stream_list = []
-    processed_feature_lens = []
+    processed_lens = []
     for stream in decode_streams:
-        # we plus 2 here because we will cut off one frame on each size of
-        # encoder_embed output as they see invalid paddings. so we need extra 2
-        # frames.
         feat, feat_len = stream.get_feature_frames(
-            (params.decode_chunk_size + 2 + params.right_context)
-            * params.subsampling_factor
+            params.decode_chunk_size * params.subsampling_factor
         )
 
         features.append(feat)
         feature_lens.append(feat_len)
         states.append(stream.states)
-        processed_feature_lens.append(stream.feature_len)
+        processed_lens.append(stream.done_frames)
         if params.decoding_method == "fast_beam_search":
             rnnt_stream_list.append(stream.rnnt_decoding_stream)
 
@@ -405,7 +401,7 @@
         torch.stack([x[0] for x in states], dim=2),
         torch.stack([x[1] for x in states], dim=2),
     ]
-    processed_feature_lens = torch.tensor(processed_feature_lens, device=device)
+    processed_lens = torch.tensor(processed_lens, device=device)
 
     encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
         x=features,
@@ -413,7 +409,7 @@
         states=states,
         left_context=params.left_context,
         right_context=params.right_context,
-        processed_lens=processed_feature_lens,
+        processed_lens=processed_lens,
     )
 
     encoder_out = model.joiner.encoder_proj(encoder_out)
@@ -429,7 +425,7 @@
             max_states=params.max_states,
         )
         decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
-        processed_lens = processed_feature_lens + encoder_out_lens
+        processed_lens = processed_lens + encoder_out_lens
         hyp_tokens = fast_beam_search(
             model, encoder_out, processed_lens, decoding_streams
         )
@@ -441,7 +437,7 @@
     finished_streams = []
    for i in range(len(decode_streams)):
         decode_streams[i].states = [states[0][i], states[1][i]]
-        decode_streams[i].feature_len += encoder_out_lens[i]
+        decode_streams[i].done_frames += encoder_out_lens[i]
         if params.decoding_method == "fast_beam_search":
             decode_streams[i].hyp = hyp_tokens[i]
         if decode_streams[i].done:
@@ -576,7 +572,8 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
-        store_transcripts(filename=recog_path, texts=sorted(results))
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
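
Reviewer note: below is a minimal standalone sketch of the new chunking behaviour that this diff introduces in decode_stream.py. The utterance is tail-padded once in set_features, and each call to get_feature_frames then returns chunk_size real frames plus a fixed pad_length of lookahead frames. The concrete values used here (subsampling_factor=4, right_context=2, decode_chunk_size=16, 80-dim features, 100 frames) are illustrative assumptions, not values fixed by this diff; the recipes read them from params.

import math

import torch

LOG_EPS = math.log(1e-10)

# Illustrative values; the recipes read these from `params`.
subsampling_factor = 4
right_context = 2
decode_chunk_size = 16

# Same formula as DecodeStream.pad_length in this diff: right_context plus
# the 2 subsampled frames cut off by encoder_embed (one on each side),
# scaled to input frames, plus 3 for lengths = ((x_lens - 1) // 2 - 1) // 2.
pad_length = (right_context + 2) * subsampling_factor + 3

features = torch.randn(100, 80)  # (num_frames, feature_dim), dummy features
num_frames = features.size(0)
# tail padding, as done once in DecodeStream.set_features
features = torch.nn.functional.pad(
    features, (0, 0, 0, pad_length), mode="constant", value=LOG_EPS
)

num_processed_frames = 0
chunk_size = decode_chunk_size * subsampling_factor
while num_processed_frames < num_frames:
    # mirrors DecodeStream.get_feature_frames
    update_length = min(num_frames - num_processed_frames, chunk_size)
    ret_length = update_length + pad_length
    chunk = features[num_processed_frames : num_processed_frames + ret_length]
    num_processed_frames += update_length
    print(chunk.shape, ret_length)

Each returned chunk overlaps the next one by pad_length frames, which is why decode_one_chunk no longer needs to request the per-chunk padding of (decode_chunk_size + 2 + right_context) * subsampling_factor frames that the removed code asked for.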