diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
index 9576eca52..eaadaf052 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
@@ -42,15 +42,16 @@ LOG_EPSILON = math.log(1e-10)
 def unstack_states(
     states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]]
 ) -> List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]]:
-    # TODO: modify doc
     """Unstack the emformer state corresponding to a batch of utterances
     into a list of states, were the i-th entry is the state from the i-th
     utterance in the batch.

     Args:
       states:
-        A list-of-list of tensors.
-        ``len(states[0])`` and ``len(states[1])`` eqaul to number of layers.
+        A list of tuples.
+        ``states[i][0]`` is the attention caches of the i-th utterance.
+        ``states[i][1]`` is the convolution caches of the i-th utterance.
+        ``len(states[i][0])`` and ``len(states[i][1])`` both equal the number of layers.  # noqa
     """

     attn_caches, conv_caches = states
@@ -146,7 +147,7 @@ class ConvolutionModule(nn.Module):
       right_context_length (int):
         Length of right context.
       channels (int):
-        The number of channels of conv layers.
+        The number of input channels and output channels of conv layers.
       kernel_size (int):
         Kernerl size of conv layers.
       bias (bool):
@@ -162,9 +163,9 @@ class ConvolutionModule(nn.Module):
         bias: bool = True,
     ) -> None:
         """Construct an ConvolutionModule object."""
-        super(ConvolutionModule, self).__init__()
-        # kernerl_size should be a odd number for 'SAME' padding
-        assert (kernel_size - 1) % 2 == 0
+        super().__init__()
+        # kernel_size should be an odd number for 'SAME' padding
+        assert (kernel_size - 1) % 2 == 0, kernel_size

         self.chunk_length = chunk_length
         self.right_context_length = right_context_length
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py
index 814bbb49f..31ad3f50a 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py
@@ -53,7 +53,7 @@ class Stream(object):
         # Initailize zero states.
         self.init_states(params)

-        # It use different attributes for different decoding methods.
+        # It uses different attributes for different decoding methods.
         self.context_size = params.context_size
         self.decoding_method = params.decoding_method
         if params.decoding_method == "greedy_search":
@@ -72,7 +72,7 @@ class Stream(object):
             self.rnnt_decoding_stream: k2.RnntDecodingStream = (
                 k2.RnntDecodingStream(decoding_graph)
             )
-            self.hyp: List[int] = None
+            self.hyp: Optional[List[int]] = None
         else:
             raise ValueError(
                 f"Unsupported decoding method: {params.decoding_method}"
@@ -134,10 +134,14 @@ class Stream(object):
             )
             for _ in range(params.num_encoder_layers)
         ]
-        self.states = [attn_caches, conv_caches]
+        self.states = (attn_caches, conv_caches)

-    def get_feature_chunk(self) -> Tuple[torch.Tensor, int]:
-        """Get a chunk of feature frames."""
+    def get_feature_chunk(self) -> torch.Tensor:
+        """Get a chunk of feature frames.
+
+        Returns:
+          A tensor of shape (ret_length, feature_dim).
+ """ update_length = min( self.num_frames - self.num_processed_frames, self.chunk_length ) @@ -153,11 +157,11 @@ class Stream(object): if self.num_processed_frames >= self.num_frames: self._done = True - return ret_feature, ret_length + return ret_feature @property def done(self) -> bool: - """Return True if `self.input_finished()` has been invoked""" + """Return True if all feature frames are processed.""" return self._done def decoding_result(self) -> List[int]: diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py index 2b024fa34..4fac405b0 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py @@ -245,8 +245,9 @@ def greedy_search( model: nn.Module, encoder_out: torch.Tensor, streams: List[Stream], -) -> List[List[int]]: +) -> None: """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: model: The transducer model. @@ -270,10 +271,9 @@ def greedy_search( device=device, dtype=torch.int64, ) - # decoder_out is of shape (N, decoder_out_dim) + # decoder_out is of shape (batch_size, 1, decoder_out_dim) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) - # logging.info(f"decoder_out shape : {decoder_out.shape}") for t in range(T): # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) @@ -427,7 +427,7 @@ def fast_beam_search_one_best( beam: float, max_states: int, max_contexts: int, -) -> List[List[int]]: +) -> None: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using modified beam search, and then @@ -449,8 +449,6 @@ def fast_beam_search_one_best( Max states per stream per frame. max_contexts: Max contexts pre stream per frame. - Returns: - Return the decoded result. """ assert encoder_out.ndim == 3 @@ -543,7 +541,8 @@ def decode_one_chunk( # before calling `stream.get_feature_chunk()` # since `stream.num_processed_frames` would be updated num_processed_frames_list.append(stream.num_processed_frames) - feature, feature_len = stream.get_feature_chunk() + feature = stream.get_feature_chunk() + feature_len = feature.size(0) feature_list.append(feature) feature_len_list.append(feature_len) state_list.append(stream.states) @@ -809,7 +808,6 @@ def main(): "fast_beam_search", "modified_beam_search", ) - # Note: params.decoding_method is currently not used. 
     params.res_dir = params.exp_dir / "streaming" / params.decoding_method

     if params.iter > 0:
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py
index cde4ab4a5..55ee487a3 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py
@@ -1,9 +1,64 @@
+#!/usr/bin/env python3
+
 import torch
+from emformer import Emformer, stack_states, unstack_states
+
+
+def test_convolution_module_forward():
+    from emformer import ConvolutionModule
+
+    B, D = 2, 256
+    chunk_length = 4
+    right_context_length = 2
+    num_chunks = 3
+    U = num_chunks * chunk_length
+    R = num_chunks * right_context_length
+    kernel_size = 31
+    conv_module = ConvolutionModule(
+        chunk_length,
+        right_context_length,
+        D,
+        kernel_size,
+    )
+
+    utterance = torch.randn(U, B, D)
+    right_context = torch.randn(R, B, D)
+
+    utterance, right_context = conv_module(utterance, right_context)
+    assert utterance.shape == (U, B, D)
+    assert right_context.shape == (R, B, D)
+
+
+def test_convolution_module_infer():
+    from emformer import ConvolutionModule
+
+    B, D = 2, 256
+    chunk_length = 4
+    right_context_length = 2
+    num_chunks = 1
+    U = num_chunks * chunk_length
+    R = num_chunks * right_context_length
+    kernel_size = 31
+    conv_module = ConvolutionModule(
+        chunk_length,
+        right_context_length,
+        D,
+        kernel_size,
+    )
+
+    utterance = torch.randn(U, B, D)
+    right_context = torch.randn(R, B, D)
+    cache = torch.randn(B, D, kernel_size - 1)
+
+    utterance, right_context, new_cache = conv_module.infer(
+        utterance, right_context, cache
+    )
+    assert utterance.shape == (U, B, D)
+    assert right_context.shape == (R, B, D)
+    assert new_cache.shape == (B, D, kernel_size - 1)


 def test_state_stack_unstack():
-    from emformer import Emformer, stack_states, unstack_states
-
     num_features = 80
     chunk_length = 32
     encoder_dim = 512
@@ -62,8 +117,6 @@

 def test_torchscript_consistency_infer():
     r"""Verify that scripting Emformer does not change the behavior of method `infer`."""  # noqa
-    from emformer import Emformer
-
     num_features = 80
     chunk_length = 32
     encoder_dim = 512
@@ -118,5 +171,7 @@


 if __name__ == "__main__":
+    test_convolution_module_forward()
+    test_convolution_module_infer()
     test_state_stack_unstack()
     test_torchscript_consistency_infer()
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py
index 507f19c1b..106f3e511 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py
@@ -619,7 +619,7 @@ def compute_loss(
     warmup: float = 1.0,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
-    Compute CTC loss given the model and its inputs.
+    Compute RNN-T loss given the model and its inputs.

     Args:
       params: