From 328ad280a40d84f44e0180e8470187635455f9d9 Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Tue, 19 Apr 2022 17:58:51 +0800
Subject: [PATCH] Support state stacking and unstacking operations for
 emformer_pruned_transducer_stateless/emformer.py

---
 .../emformer.py      | 75 ++++++++++++++++++++-
 .../test_emformer.py | 39 ++++++++++
 2 files changed, 111 insertions(+), 3 deletions(-)

diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py
index 9eb5b966f..b6f93b4c7 100644
--- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py
+++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py
@@ -77,6 +77,75 @@ def _gen_attention_mask_block(
     return torch.cat(mask_block, dim=1)
 
 
+def unstack_states(
+    states: List[List[torch.Tensor]],
+) -> List[List[List[torch.Tensor]]]:
+    """Unstack the emformer state corresponding to a batch of utterances
+    into a list of states, where the i-th entry is the state from the i-th
+    utterance in the batch.
+
+    Args:
+      states:
+        A list-of-list of tensors. ``len(states)`` equals the number of
+        layers in the emformer. ``states[i]`` contains the states for
+        the i-th layer. ``states[i][k]`` is either a 3-D tensor of shape
+        ``(T, N, C)`` or a 2-D tensor of shape ``(C, N)``.
+    """
+    batch_size = states[0][0].size(1)
+    num_layers = len(states)
+
+    ans = [None] * batch_size
+    for i in range(batch_size):
+        ans[i] = [[] for _ in range(num_layers)]
+
+    for li, layer in enumerate(states):
+        for s in layer:
+            s_list = s.unbind(dim=1)
+            # We will use stack(dim=1) later in stack_states()
+            for bi, b in enumerate(ans):
+                b[li].append(s_list[bi])
+    return ans
+
+
+def stack_states(
+    state_list: List[List[List[torch.Tensor]]],
+) -> List[List[torch.Tensor]]:
+    """Stack a list of emformer states that correspond to separate utterances
+    into a single emformer state, so that it can be used as an input for
+    emformer when those utterances are formed into a batch.
+
+    Note:
+      It is the inverse of :func:`unstack_states`.
+
+    Args:
+      state_list:
+        Each element in state_list corresponds to the internal state
+        of the emformer model for a single utterance.
+    Returns:
+      Return a new state corresponding to a batch of utterances.
+      See the input argument of :func:`unstack_states` for the meaning
+      of the returned tensor.
+    """
+    batch_size = len(state_list)
+    ans = []
+    for layer in state_list[0]:
+        # layer is a list of tensors
+        if batch_size > 1:
+            ans.append([[s] for s in layer])
+            # Note: We will stack ans[layer][s][] later to get ans[layer][s]
+        else:
+            ans.append([s.unsqueeze(1) for s in layer])
+
+    for b, states in enumerate(state_list[1:], 1):
+        for li, layer in enumerate(states):
+            for si, s in enumerate(layer):
+                ans[li][si].append(s)
+                if b == batch_size - 1:
+                    ans[li][si] = torch.stack(ans[li][si], dim=1)
+                    # We will use unbind(dim=1) later in unstack_states()
+    return ans
+
+
 class EmformerAttention(nn.Module):
     r"""Emformer layer attention module.
 
@@ -424,9 +493,9 @@ class EmformerAttention(nn.Module):
         # key, value: [memory, right context, left context, uttrance]
         KV = (
             memory.size(0)
-            + right_context.size(0)
-            + left_context_key.size(0)
-            + utterance.size(0)
+            + right_context.size(0)  # noqa
+            + left_context_key.size(0)  # noqa
+            + utterance.size(0)  # noqa
         )
         attention_mask = torch.zeros(Q, KV).to(
             dtype=torch.bool, device=utterance.device
diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py
index abc023bb7..ecfe24c61 100644
--- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py
+++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py
@@ -682,6 +682,44 @@ def test_emformer_infer_batch_single_consistency():
     assert torch.allclose(batch_logits, single_logits, atol=1e-5, rtol=0.0)
 
 
+def test_emformer_infer_states_stack():
+    from emformer import Emformer, unstack_states, stack_states
+
+    num_features = 80
+    output_dim = 1000
+    chunk_length = 8
+    U = chunk_length
+    L, R = 128, 4
+    B, D = 2, 256
+    num_encoder_layers = 2
+    for use_memory in [True, False]:
+        if use_memory:
+            M = 3
+        else:
+            M = 0
+        model = Emformer(
+            num_features=num_features,
+            output_dim=output_dim,
+            chunk_length=chunk_length,
+            subsampling_factor=4,
+            d_model=D,
+            num_encoder_layers=num_encoder_layers,
+            left_context_length=L,
+            right_context_length=R,
+            max_memory_size=M,
+            vgg_frontend=False,
+        )
+
+        x = torch.randn(B, U + R + 3, num_features)
+        x_lens = torch.full((B,), U + R + 3)
+        logits, output_lengths, states = model.infer(x, x_lens)
+        states2 = stack_states(unstack_states(states))
+
+        for ss, ss2 in zip(states, states2):
+            for s, s2 in zip(ss, ss2):
+                assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}"
+
+
 if __name__ == "__main__":
     test_emformer_attention_forward()
     test_emformer_attention_infer()
@@ -695,3 +733,4 @@ if __name__ == "__main__":
     test_emformer_layer_forward_infer_consistency()
     test_emformer_encoder_forward_infer_consistency()
     test_emformer_infer_batch_single_consistency()
+    test_emformer_infer_states_stack()
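
Note for reviewers: below is a minimal usage sketch (not part of the patch) of how unstack_states() and stack_states() are meant to be paired around Emformer.infer() when the set of active streams changes between chunks. The model hyper-parameters are copied from test_emformer_infer_states_stack() above; the re-batching step at the end, and the assumption that the rebuilt state can be fed back to a later infer() call, are illustrative rather than code taken from the repository.

# Usage sketch (assumes the emformer.py from this patch is importable).
import torch

from emformer import Emformer, stack_states, unstack_states

model = Emformer(
    num_features=80,
    output_dim=1000,
    chunk_length=8,
    subsampling_factor=4,
    d_model=256,
    num_encoder_layers=2,
    left_context_length=128,
    right_context_length=4,
    max_memory_size=3,
    vgg_frontend=False,
)
model.eval()

# Run one chunk for a batch of 2 streams to obtain a batched state.
x = torch.randn(2, 8 + 4 + 3, 80)
x_lens = torch.full((2,), 8 + 4 + 3)
with torch.no_grad():
    logits, out_lens, batch_states = model.infer(x, x_lens)

# Split the batched state into one state per stream ...
per_stream_states = unstack_states(batch_states)
assert len(per_stream_states) == 2

# ... and rebuild a batched state for whatever streams are still active,
# e.g. only stream 0 if stream 1 has finished. The result has the same
# layout as the `states` returned by model.infer(); this sketch assumes
# it can be passed back to the next infer() call for the remaining streams.
next_batch_states = stack_states([per_stream_states[0]])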