diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py
index fd1be9cdf..a77d1b141 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py
@@ -39,6 +39,8 @@ class DecodeStream(object):
         if decoding_graph is not None:
             assert device == decoding_graph.device
 
+        self.params = params
+
         # It contains a 2-D tensors representing the feature frames.
         self.features: torch.Tensor = None
         # how many frames are processed. (before subsampling).
@@ -49,11 +51,12 @@ class DecodeStream(object):
         # The decoding result (partial or final) of current utterance.
         self.hyp: List = []
 
+        self.feature_len: int = 0
+
         if params.decoding_method == "greedy_search":
             self.hyp = [params.blank_id] * params.context_size
         elif params.decoding_method == "fast_beam_search":
             # feature_len is needed to get partial results.
-            self.feature_len: int = 0
             # The rnnt_decoding_stream for fast_beam_search.
             self.rnnt_decoding_stream: k2.RnntDecodingStream = (
                 k2.RnntDecodingStream(decoding_graph)
@@ -110,7 +113,10 @@ class DecodeStream(object):
             + ret_chunk_size,
             :,
         ]
-        self.num_processed_frames += chunk_size
+        self.num_processed_frames += (
+            chunk_size
+            - self.params.right_context * self.params.subsampling_factor
+        )
         if self.num_processed_frames >= self.features.size(0):
             self._done = True
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
index 8a039d6f9..cb817eac0 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
@@ -203,7 +203,9 @@ class Conformer(EncoderInterface):
         warmup: float = 1.0,
         chunk_size: int = 16,
         left_context: int = 64,
+        right_context: int = 4,
         simulate_streaming: bool = False,
+        processed_lens: Optional[Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -269,7 +271,20 @@ class Conformer(EncoderInterface):
             {(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)},
             given {states[1].shape}."""
 
-            src_key_padding_mask = make_pad_mask(lengths + left_context)
+            # src_key_padding_mask = make_pad_mask(lengths + left_context)
+            src_key_padding_mask = make_pad_mask(lengths)
+
+            assert processed_lens is not None
+
+            processed_mask = torch.arange(left_context, device=x.device).expand(
+                x.size(0), left_context
+            )
+            processed_lens = processed_lens.view(x.size(0), 1)
+            processed_mask = (processed_lens <= processed_mask).flip(1)
+
+            src_key_padding_mask = torch.cat(
+                [processed_mask, src_key_padding_mask], dim=1
+            )
 
             embed = self.encoder_embed(x)
             embed, pos_enc = self.encoder_pos(embed, left_context)
@@ -282,8 +297,11 @@ class Conformer(EncoderInterface):
                 warmup=warmup,
                 states=states,
                 left_context=left_context,
+                right_context=right_context,
             )  # (T, B, F)
-
+            if right_context > 0:
+                x = x[0:-right_context, ...]
+                lengths -= right_context
         else:
             src_key_padding_mask = make_pad_mask(lengths)
 
@@ -465,6 +483,7 @@ class ConformerEncoderLayer(nn.Module):
         src_key_padding_mask: Optional[Tensor] = None,
         warmup: float = 1.0,
         left_context: int = 0,
+        right_context: int = 0,
     ) -> Tensor:
         """
         Pass the input through the encoder layer.
@@ -504,7 +523,12 @@ class ConformerEncoderLayer(nn.Module):
 
         key = torch.cat([states[0], src], dim=0)
         val = key
-        states[0] = key[-left_context:, ...]
+        if right_context > 0:
+            states[0] = key[
+                -(left_context + right_context) : -right_context, ...  # noqa
+            ]
+        else:
+            states[0] = key[-left_context:, ...]
 
         # multi-headed self-attention module
         src_att = self.self_attn(
@@ -520,7 +544,7 @@ class ConformerEncoderLayer(nn.Module):
         src = src + self.dropout(src_att)
 
         # convolution module
-        conv, conv_cache = self.conv_module(src, states[1])
+        conv, conv_cache = self.conv_module(src, states[1], right_context)
         states[1] = conv_cache
 
         src = src + self.dropout(conv)
@@ -604,6 +628,7 @@ class ConformerEncoder(nn.Module):
         src_key_padding_mask: Optional[Tensor] = None,
         warmup: float = 1.0,
         left_context: int = 0,
+        right_context: int = 0,
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.
 
@@ -655,6 +680,7 @@ class ConformerEncoder(nn.Module):
                 src_key_padding_mask=src_key_padding_mask,
                 warmup=warmup,
                 left_context=left_context,
+                right_context=right_context,
             )
             states[0][layer_index] = cache[0]
             states[1][layer_index] = cache[1]
@@ -1306,6 +1332,7 @@ class ConvolutionModule(nn.Module):
         self,
         x: Tensor,
         cache: Optional[Tensor] = None,
+        right_context=0,
     ) -> Tuple[Tensor, Tensor]:
         """Compute convolution module.
 
@@ -1342,7 +1369,15 @@ class ConvolutionModule(nn.Module):
             ), "Cache should be None in training time"
             assert cache.size(0) == self.lorder
             x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
-            cache = x.permute(2, 0, 1)[-self.lorder :, ...]  # noqa
+            if right_context > 0:
+                cache = x.permute(2, 0, 1)[
+                    -(self.lorder + right_context) : (  # noqa
+                        -right_context
+                    ),
+                    ...,
+                ]
+            else:
+                cache = x.permute(2, 0, 1)[-self.lorder :, ...]  # noqa
 
         x = self.depthwise_conv(x)
         x = self.deriv_balancer2(x)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py
index ced5e2a3b..c61975ccc 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py
@@ -205,6 +205,13 @@ def get_parser():
         help="left context can be seen during decoding (in frames after subsampling)",
     )
 
+    parser.add_argument(
+        "--right-context",
+        type=int,
+        default=4,
+        help="right context can be seen during decoding (in frames after subsampling)",
+    )
+
     parser.add_argument(
         "--num-decode-streams",
         type=int,
@@ -355,8 +362,8 @@ def decode_one_chunk(
         features.append(feat)
         feature_lens.append(feat_len)
         states.append(stream.states)
+        processed_feature_lens.append(stream.feature_len)
         if params.decoding_method == "fast_beam_search":
-            processed_feature_lens.append(stream.feature_len)
             rnnt_stream_list.append(stream.rnnt_decoding_stream)
 
     feature_lens = torch.tensor(feature_lens, device=device)
@@ -364,15 +371,18 @@ def decode_one_chunk(
 
     # if T is less than 7 there will be an error in time reduction layer,
     # because we subsample features with ((x_len - 1) // 2 - 1) // 2
-    if features.size(1) < 7:
-        feature_lens += 7 - features.size(1)
+    tail_length = 7 + params.right_context * params.subsampling_factor
+    if features.size(1) < tail_length:
+        feature_lens += tail_length - features.size(1)
         features = torch.cat(
             [
                 features,
                 torch.tensor(
                     LOG_EPS, dtype=features.dtype, device=device
                 ).expand(
-                    features.size(0), 7 - features.size(1), features.size(2)
+                    features.size(0),
+                    tail_length - features.size(1),
+                    features.size(2),
                 ),
             ],
             dim=1,
@@ -382,6 +392,7 @@ def decode_one_chunk(
         torch.stack([x[0] for x in states], dim=2),
         torch.stack([x[1] for x in states], dim=2),
     ]
+    processed_feature_lens = torch.tensor(processed_feature_lens, device=device)
 
     # Note: states will be modified in streaming_forward.
    encoder_out, encoder_out_lens = model.encoder.streaming_forward(
@@ -389,7 +400,10 @@ def decode_one_chunk(
         x_lens=feature_lens,
         states=states,
         left_context=params.left_context,
+        right_context=params.right_context,
+        processed_lens=processed_feature_lens,
     )
+
     encoder_out = model.joiner.encoder_proj(encoder_out)
 
     if params.decoding_method == "greedy_search":
@@ -402,9 +416,6 @@ def decode_one_chunk(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-        processed_feature_lens = torch.tensor(
-            processed_feature_lens, device=device
-        )
         decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
         processed_lens = processed_feature_lens + encoder_out_lens
         hyp_tokens = fast_beam_search(
@@ -418,8 +429,8 @@ def decode_one_chunk(
     finished_streams = []
     for i in range(len(decode_streams)):
         decode_streams[i].states = [states[0][i], states[1][i]]
+        decode_streams[i].feature_len += encoder_out_lens[i]
         if params.decoding_method == "fast_beam_search":
-            decode_streams[i].feature_len += encoder_out_lens[i]
             decode_streams[i].hyp = hyp_tokens[i]
         if decode_streams[i].done:
             finished_streams.append(i)
@@ -489,7 +500,8 @@ def decode_dataset(
         samples = torch.from_numpy(audio).squeeze(0)
 
         fbank = Fbank(opts)
-        decode_stream.set_features(fbank(samples.to(device)))
+        feature = fbank(samples.to(device))
+        decode_stream.set_features(feature)
         decode_stream.ground_truth = cut.supervisions[0].text
 
         decode_streams.append(decode_stream)
@@ -541,14 +553,14 @@ def decode_dataset(
 def save_results(
     params: AttributeDict,
     test_set_name: str,
-    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
+    results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
-        store_transcripts(filename=recog_path, texts=results)
+        store_transcripts(filename=recog_path, texts=sorted(results))
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
@@ -602,6 +614,7 @@ def main():
     # for streaming
     params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
     params.suffix += f"-left-context-{params.left_context}"
+    params.suffix += f"-right-context-{params.right_context}"
 
     # for fast_beam_search
     if params.decoding_method == "fast_beam_search":
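Reviewer note: the least obvious part of this patch is the key-padding mask that `streaming_forward` now builds for the cached left-context frames in conformer.py. Below is a minimal standalone sketch of that logic, assuming the same semantics as icefall's `make_pad_mask` (True marks padded positions); the helper `build_src_key_padding_mask` and the example lengths are invented here for illustration and are not part of the patch.

```python
# Standalone sketch of the left-context masking added in streaming_forward.
# Shapes and example values are illustrative, not taken from the recipe.
import torch


def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Return a bool mask of shape (batch, max_len); True marks padding."""
    max_len = int(lengths.max())
    seq = torch.arange(max_len, device=lengths.device).expand(
        lengths.size(0), max_len
    )
    return seq >= lengths.unsqueeze(1)


def build_src_key_padding_mask(
    lengths: torch.Tensor,  # valid frames in the current chunk, shape (batch,)
    processed_lens: torch.Tensor,  # frames already processed per stream, (batch,)
    left_context: int,
) -> torch.Tensor:
    # Mask for the current chunk: True where the chunk is padded.
    chunk_mask = make_pad_mask(lengths)

    # Mask for the cached left-context frames: a stream that has processed
    # fewer than `left_context` frames still has empty cache slots, and those
    # slots must be masked out.  flip(1) moves the valid (most recent) cache
    # entries to the right end, next to the current chunk.
    processed_mask = torch.arange(left_context, device=lengths.device).expand(
        lengths.size(0), left_context
    )
    processed_mask = (processed_lens.view(-1, 1) <= processed_mask).flip(1)

    # Cache positions come first, followed by the chunk positions.
    return torch.cat([processed_mask, chunk_mask], dim=1)


if __name__ == "__main__":
    lengths = torch.tensor([16, 12])  # chunk frames per stream
    processed_lens = torch.tensor([64, 20])  # frames each stream has already seen
    mask = build_src_key_padding_mask(lengths, processed_lens, left_context=64)
    print(mask.shape)  # torch.Size([2, 80]): 64 cache positions + 16 chunk positions
```

With `processed_lens = [64, 20]` and `left_context = 64`, the second stream has only 20 valid cache entries; the flip keeps them adjacent to the current chunk, which matches how the layer cache is prepended via `torch.cat([states[0], src], dim=0)` in `ConformerEncoderLayer.chunk_forward`.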