From a1cbe1fd9cfcc3e0e6f24b05eb3462b18a8cd17d Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Mon, 13 Jun 2022 12:47:35 +0800
Subject: [PATCH] fix doc of stack and unstack, test case with batch_size=1

---
 .../emformer.py      | 25 +++++---
 .../test_emformer.py | 61 ++++++++++---------
 2 files changed, 49 insertions(+), 37 deletions(-)

diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
index eaadaf052..46993da48 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
@@ -43,15 +43,22 @@ def unstack_states(
     states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]]
 ) -> List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]]:
     """Unstack the emformer state corresponding to a batch of utterances
-    into a list of states, were the i-th entry is the state from the i-th
+    into a list of states, where the i-th entry is the state from the i-th
     utterance in the batch.
 
     Args:
       states:
-        A list of tuples.
-        ``states[i][0]`` is the attention caches of i-th utterance.
-        ``states[i][1]`` is the convolution caches of i-th utterance.
-        ``len(states[i][0])`` and ``len(states[i][1])`` both eqaul to number of layers.  # noqa
+        A tuple of 2 elements.
+        ``states[0]`` is the attention caches of a batch of utterances.
+        ``states[1]`` is the convolution caches of a batch of utterances.
+        ``len(states[0])`` and ``len(states[1])`` both equal the number of layers.  # noqa
+
+    Returns:
+      A list of states.
+      ``states[i]`` is a tuple of 2 elements of the i-th utterance.
+      ``states[i][0]`` is the attention caches of the i-th utterance.
+      ``states[i][1]`` is the convolution caches of the i-th utterance.
+      ``len(states[i][0])`` and ``len(states[i][1])`` both equal the number of layers.  # noqa
     """
     attn_caches, conv_caches = states
 
@@ -85,7 +92,6 @@ def unstack_states(
 def stack_states(
     state_list: List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]]
 ) -> Tuple[List[List[torch.Tensor]], List[torch.Tensor]]:
-    # TODO: modify doc
     """Stack list of emformer states that correspond to separate utterances
     into a single emformer state so that it can be used as an input for
     emformer when those utterances are formed into a batch.
@@ -97,8 +103,13 @@ def stack_states(
       state_list:
         Each element in state_list corresponding to the internal state
         of the emformer model for a single utterance.
+        ``state_list[i]`` is a tuple of 2 elements of the i-th utterance.
+        ``state_list[i][0]`` is the attention caches of the i-th utterance.
+        ``state_list[i][1]`` is the convolution caches of the i-th utterance.
+        ``len(state_list[i][0])`` and ``len(state_list[i][1])`` both equal the number of layers.  # noqa
+
     Returns:
-      Return a new state corresponding to a batch of utterances.
+      A new state corresponding to a batch of utterances.
       See the input argument of :func:`unstack_states` for the meaning
       of the returned tensor.
""" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py index 99fb0a877..f80f4e367 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py @@ -85,7 +85,6 @@ def test_state_stack_unstack(): left_context_length = 32 right_context_length = 8 memory_size = 32 - batch_size = 2 model = Emformer( num_features=num_features, @@ -98,40 +97,42 @@ def test_state_stack_unstack(): right_context_length=right_context_length, memory_size=memory_size, ) - attn_caches = [ - [ - torch.zeros(memory_size, batch_size, encoder_dim), - torch.zeros(left_context_length // 4, batch_size, encoder_dim), - torch.zeros( - left_context_length // 4, - batch_size, - encoder_dim, - ), + + for batch_size in [1, 2]: + attn_caches = [ + [ + torch.zeros(memory_size, batch_size, encoder_dim), + torch.zeros(left_context_length // 4, batch_size, encoder_dim), + torch.zeros( + left_context_length // 4, + batch_size, + encoder_dim, + ), + ] + for _ in range(num_encoder_layers) ] - for _ in range(num_encoder_layers) - ] - conv_caches = [ - torch.zeros(batch_size, encoder_dim, kernel_size - 1) - for _ in range(num_encoder_layers) - ] - states = [attn_caches, conv_caches] - x = torch.randn(batch_size, 23, num_features) - x_lens = torch.full((batch_size,), 23) - num_processed_frames = torch.full((batch_size,), 0) - y, y_lens, states = model.infer( - x, x_lens, num_processed_frames=num_processed_frames, states=states - ) + conv_caches = [ + torch.zeros(batch_size, encoder_dim, kernel_size - 1) + for _ in range(num_encoder_layers) + ] + states = [attn_caches, conv_caches] + x = torch.randn(batch_size, 23, num_features) + x_lens = torch.full((batch_size,), 23) + num_processed_frames = torch.full((batch_size,), 0) + y, y_lens, states = model.infer( + x, x_lens, num_processed_frames=num_processed_frames, states=states + ) - state_list = unstack_states(states) - states2 = stack_states(state_list) + state_list = unstack_states(states) + states2 = stack_states(state_list) - for ss, ss2 in zip(states[0], states2[0]): - for s, s2 in zip(ss, ss2): + for ss, ss2 in zip(states[0], states2[0]): + for s, s2 in zip(ss, ss2): + assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}" + + for s, s2 in zip(states[1], states2[1]): assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}" - for s, s2 in zip(states[1], states2[1]): - assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}" - def test_torchscript_consistency_infer(): r"""Verify that scripting Emformer does not change the behavior of method `infer`.""" # noqa