From 5dc5f8305ab0ebafb985addd6512e2aaba4ab68b Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Fri, 13 May 2022 17:07:40 +0800
Subject: [PATCH] add emformer attention module

---
 .../emformer.py      | 482 ++++++++++++++++++
 .../test_emformer.py |  91 ++++
 2 files changed, 573 insertions(+)

diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
index 249167041..24ee8b0be 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
@@ -151,3 +151,485 @@ class RelPositionalEncoding(torch.nn.Module):
         self.gen_pe_negative()
         pos_emb = self.get_pe(pos_len, neg_len, x.device, x.dtype)
         return self.dropout(x), self.dropout(pos_emb)
+
+
+class EmformerAttention(nn.Module):
+    r"""Emformer layer attention module.
+
+    Args:
+      embed_dim (int):
+        Embedding dimension.
+      nhead (int):
+        Number of attention heads in each Emformer layer.
+      dropout (float, optional):
+        Dropout probability. (Default: 0.0)
+      tanh_on_mem (bool, optional):
+        If ``True``, applies tanh to memory elements. (Default: ``False``)
+      negative_inf (float, optional):
+        Value to use for negative infinity in attention weights.
+        (Default: -1e8)
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        nhead: int,
+        dropout: float = 0.0,
+        tanh_on_mem: bool = False,
+        negative_inf: float = -1e8,
+    ):
+        super().__init__()
+
+        if embed_dim % nhead != 0:
+            raise ValueError(
+                f"embed_dim ({embed_dim}) is not a multiple of "
+                f"nhead ({nhead})."
+            )
+
+        self.embed_dim = embed_dim
+        self.nhead = nhead
+        self.tanh_on_mem = tanh_on_mem
+        self.negative_inf = negative_inf
+        self.head_dim = embed_dim // nhead
+        self.dropout = dropout
+
+        self.emb_to_key_value = ScaledLinear(
+            embed_dim, 2 * embed_dim, bias=True
+        )
+        self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True)
+        self.out_proj = ScaledLinear(
+            embed_dim, embed_dim, bias=True, initial_scale=0.25
+        )
+
+        # linear transformation for positional encoding.
+        self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
+        # these two learnable biases are used in matrix c and matrix d
+        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3  # noqa
+        self.pos_bias_u = nn.Parameter(torch.Tensor(nhead, self.head_dim))
+        self.pos_bias_v = nn.Parameter(torch.Tensor(nhead, self.head_dim))
+        self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
+        self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
+
+        self._reset_parameters()
+
+    def _pos_bias_u(self):
+        return self.pos_bias_u * self.pos_bias_u_scale.exp()
+
+    def _pos_bias_v(self):
+        return self.pos_bias_v * self.pos_bias_v_scale.exp()
+
+    def _reset_parameters(self) -> None:
+        nn.init.normal_(self.pos_bias_u, std=0.01)
+        nn.init.normal_(self.pos_bias_v, std=0.01)
+
+    def _gen_attention_probs(
+        self,
+        attention_weights: torch.Tensor,
+        attention_mask: torch.Tensor,
+        padding_mask: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        """Given the entire attention weights, mask out unnecessary
+        connections and optionally the padding positions, to obtain the
+        underlying chunk-wise attention probabilities.
+
+        B: batch size;
+        Q: length of query;
+        KV: length of key and value.
+
+        Args:
+          attention_weights (torch.Tensor):
+            Attention weights computed on the entire concatenated tensor
+            with shape (B * nhead, Q, KV).
+          attention_mask (torch.Tensor):
+            Mask tensor where chunk-wise connections are filled with `False`,
+            and other unnecessary connections are filled with `True`,
+            with shape (Q, KV).
+          padding_mask (torch.Tensor, optional):
+            Mask tensor where the padding positions are filled with `True`,
+            and other positions are filled with `False`, with shape (B, KV).
+
+        Returns:
+          A tensor of shape (B * nhead, Q, KV).
+        """
+        attention_weights_float = attention_weights.float()
+        attention_weights_float = attention_weights_float.masked_fill(
+            attention_mask.unsqueeze(0), self.negative_inf
+        )
+        if padding_mask is not None:
+            Q = attention_weights.size(1)
+            B = attention_weights.size(0) // self.nhead
+            attention_weights_float = attention_weights_float.view(
+                B, self.nhead, Q, -1
+            )
+            attention_weights_float = attention_weights_float.masked_fill(
+                padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
+                self.negative_inf,
+            )
+            attention_weights_float = attention_weights_float.view(
+                B * self.nhead, Q, -1
+            )
+
+        attention_probs = nn.functional.softmax(
+            attention_weights_float, dim=-1
+        ).type_as(attention_weights)
+
+        attention_probs = nn.functional.dropout(
+            attention_probs, p=self.dropout, training=self.training
+        )
+        return attention_probs
+
+    def _rel_shift(self, x: torch.Tensor) -> torch.Tensor:
+        """Compute relative positional encoding.
+
+        Args:
+          x: Input tensor, of shape (B, nhead, U, PE).
+            U is the length of the query vector.
+            For training mode, PE = 2 * U - 1;
+            for inference mode, PE = L + 2 * U - 1.
+
+        Returns:
+          A tensor of shape (B, nhead, U, out_len).
+          For training mode, out_len = U;
+          for inference mode, out_len = L + U.
+        """
+        B, nhead, U, PE = x.size()
+        B_stride = x.stride(0)
+        nhead_stride = x.stride(1)
+        U_stride = x.stride(2)
+        PE_stride = x.stride(3)
+        out_len = PE - (U - 1)
+        return x.as_strided(
+            size=(B, nhead, U, out_len),
+            stride=(B_stride, nhead_stride, U_stride - PE_stride, PE_stride),
+            storage_offset=PE_stride * (U - 1),
+        )
+
+    def _forward_impl(
+        self,
+        utterance: torch.Tensor,
+        lengths: torch.Tensor,
+        right_context: torch.Tensor,
+        summary: torch.Tensor,
+        memory: torch.Tensor,
+        attention_mask: torch.Tensor,
+        pos_emb: torch.Tensor,
+        left_context_key: Optional[torch.Tensor] = None,
+        left_context_val: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Underlying chunk-wise attention implementation."""
+        U, B, _ = utterance.size()
+        R = right_context.size(0)
+        M = memory.size(0)
+        scaling = float(self.head_dim) ** -0.5
+
+        # compute query with [right_context, utterance, summary].
+        query = self.emb_to_query(
+            torch.cat([right_context, utterance, summary])
+        )
+        # compute key and value with [memory, right_context, utterance].
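+        # emb_to_key_value projects to 2 * embed_dim; chunking the result
+        # along the feature dimension yields the key and the value, each of
+        # shape (M + R + U, B, embed_dim) at this point.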
+        key, value = self.emb_to_key_value(
+            torch.cat([memory, right_context, utterance])
+        ).chunk(chunks=2, dim=2)
+
+        if left_context_key is not None and left_context_val is not None:
+            # now compute key and value with
+            # [memory, right context, left context, utterance];
+            # this is used in inference mode
+            key = torch.cat([key[: M + R], left_context_key, key[M + R :]])
+            value = torch.cat(
+                [value[: M + R], left_context_val, value[M + R :]]
+            )
+        Q = query.size(0)
+        KV = key.size(0)
+
+        reshaped_key, reshaped_value = [
+            tensor.contiguous()
+            .view(KV, B * self.nhead, self.head_dim)
+            .transpose(0, 1)
+            for tensor in [key, value]
+        ]  # both of shape (B * nhead, KV, head_dim)
+        reshaped_query = (
+            query.contiguous().view(Q, B, self.nhead, self.head_dim) * scaling
+        )
+
+        # compute attention score
+        # first, compute attention matrix a and matrix c
+        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3  # noqa
+        query_with_bias_u = (
+            (reshaped_query + self._pos_bias_u())
+            .view(Q, B * self.nhead, self.head_dim)
+            .transpose(0, 1)
+        )  # (B * nhead, Q, head_dim)
+        matrix_ac = torch.bmm(
+            query_with_bias_u, reshaped_key.transpose(1, 2)
+        )  # (B * nhead, Q, KV)
+
+        # second, compute attention matrix b and matrix d
+        # relative positional encoding is applied on the part of attention
+        # between chunk (in query) and itself as well as its left context
+        # (in key)
+        utterance_with_bias_v = (
+            reshaped_query[R : R + U] + self._pos_bias_v()
+        ).permute(1, 2, 0, 3)
+        # (B, nhead, U, head_dim)
+        PE = pos_emb.size(0)
+        if left_context_key is not None and left_context_val is not None:
+            # inference mode
+            L = left_context_key.size(0)
+            assert PE == L + 2 * U - 1
+        else:
+            # training mode
+            assert PE == 2 * U - 1
+        pos_emb = (
+            self.linear_pos(pos_emb)
+            .view(PE, self.nhead, self.head_dim)
+            .transpose(0, 1)
+            .unsqueeze(0)
+        )  # (1, nhead, PE, head_dim)
+        matrix_bd_utterance = torch.matmul(
+            utterance_with_bias_v, pos_emb.transpose(-2, -1)
+        )  # (B, nhead, U, PE)
+        # rel-shift operation
+        matrix_bd_utterance = self._rel_shift(matrix_bd_utterance)
+        # (B, nhead, U, U) for training mode;
+        # (B, nhead, U, L + U) for inference mode.
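+        # the rel-pos scores apply only to utterance queries (rows R : R + U
+        # of the final weights) attending to utterance keys, plus the cached
+        # left context in inference mode (columns M + R onward); the lines
+        # below scatter them into a zero matrix of the full (Q, KV) shape.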
+        matrix_bd_utterance = matrix_bd_utterance.contiguous().view(
+            B * self.nhead, U, -1
+        )
+        matrix_bd = torch.zeros_like(matrix_ac)
+        matrix_bd[:, R : R + U, M + R :] = matrix_bd_utterance
+
+        attention_weights = matrix_ac + matrix_bd
+
+        # compute padding mask
+        if B == 1:
+            padding_mask = None
+        else:
+            padding_mask = make_pad_mask(KV - U + lengths)
+
+        # compute attention probabilities
+        attention_probs = self._gen_attention_probs(
+            attention_weights, attention_mask, padding_mask
+        )
+
+        # compute attention outputs
+        attention = torch.bmm(attention_probs, reshaped_value)
+        assert attention.shape == (B * self.nhead, Q, self.head_dim)
+        attention = (
+            attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim)
+        )
+
+        # apply output projection
+        outputs = self.out_proj(attention)
+
+        output_right_context_utterance = outputs[: R + U]
+        output_memory = outputs[R + U :]
+        if self.tanh_on_mem:
+            output_memory = torch.tanh(output_memory)
+        else:
+            output_memory = torch.clamp(output_memory, min=-10, max=10)
+
+        return output_right_context_utterance, output_memory, key, value
+
+    def forward(
+        self,
+        utterance: torch.Tensor,
+        lengths: torch.Tensor,
+        right_context: torch.Tensor,
+        summary: torch.Tensor,
+        memory: torch.Tensor,
+        attention_mask: torch.Tensor,
+        pos_emb: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # TODO: Modify docs.
+        """Forward pass for training mode.
+
+        B: batch size;
+        D: embedding dimension;
+        R: length of the hard-copied right contexts;
+        U: length of full utterance;
+        S: length of summary vectors;
+        M: length of memory vectors.
+
+        It computes a `big` attention matrix on the full utterance and
+        then utilizes a pre-computed mask to simulate chunk-wise attention.
+
+        It concatenates three blocks: hard-copied right contexts,
+        full utterance, and summary vectors, as a `big` block,
+        to compute the query tensor:
+          query = [right_context, utterance, summary],
+        with length Q = R + U + S.
+        It concatenates three blocks: memory vectors, hard-copied right
+        contexts, and full utterance, as another `big` block,
+        to compute the key and value tensors:
+          key & value = [memory, right_context, utterance],
+        with length KV = M + R + U.
+        Attention scores are computed with the above `big` query and key.
+
+        Then the underlying chunk-wise attention is obtained by applying
+        the attention mask. Suppose
+          c_i: chunk at index i;
+          r_i: right context that c_i can use;
+          l_i: left context that c_i can use;
+          m_i: past memory vectors from the previous layer that c_i can use;
+          s_i: summary vector of c_i.
+        The target chunk-wise attention is:
+          c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key);
+          s_i (in query) -> l_i, c_i, r_i (in key).
+
+        Relative positional encoding is applied on the part of attention
+        between utterance (in query) and utterance (in key). Actually, it is
+        applied on the part of attention between each chunk (in query) and
+        itself as well as its left context (in key), after applying the mask:
+          c_i -> l_i, c_i.
+
+        Args:
+          utterance (torch.Tensor):
+            Full utterance frames, with shape (U, B, D).
+          lengths (torch.Tensor):
+            With shape (B,) and i-th element representing the number of valid
+            frames for the i-th batch element in utterance.
+          right_context (torch.Tensor):
+            Hard-copied right context frames, with shape (R, B, D),
+            where R = num_chunks * right_context_length.
+          summary (torch.Tensor):
+            Summary elements with shape (S, B, D), where S = num_chunks.
+            It is an empty tensor when memory is not used.
+          memory (torch.Tensor):
+            Memory elements, with shape (M, B, D), where M = num_chunks - 1.
+            It is an empty tensor when memory is not used.
+          attention_mask (torch.Tensor):
+            Pre-computed attention mask to simulate the underlying chunk-wise
+            attention, with shape (Q, KV).
+          pos_emb (torch.Tensor):
+            Position encoding embedding, with shape (PE, D),
+            where PE = 2 * U - 1.
+
+        Returns:
+          A tuple containing 2 tensors:
+            - output of right context and utterance, with shape (R + U, B, D).
+            - memory output, with shape (M, B, D), where M = S - 1 or M = 0.
+        """
+        (
+            output_right_context_utterance,
+            output_memory,
+            _,
+            _,
+        ) = self._forward_impl(
+            utterance,
+            lengths,
+            right_context,
+            summary,
+            memory,
+            attention_mask,
+            pos_emb,
+        )
+        return output_right_context_utterance, output_memory[:-1]
+
+    def infer(
+        self,
+        utterance: torch.Tensor,
+        lengths: torch.Tensor,
+        right_context: torch.Tensor,
+        summary: torch.Tensor,
+        memory: torch.Tensor,
+        left_context_key: torch.Tensor,
+        left_context_val: torch.Tensor,
+        pos_emb: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Forward pass for inference.
+
+        B: batch size;
+        D: embedding dimension;
+        R: length of right context;
+        U: length of utterance, i.e., current chunk;
+        L: length of cached left context;
+        S: length of summary vectors, S = 1;
+        M: length of cached memory vectors.
+
+        It concatenates the right context, utterance (i.e., current chunk),
+        and summary vector of the current chunk, to compute the query tensor:
+          query = [right_context, utterance, summary],
+        with length Q = R + U + S.
+        It concatenates the memory vectors, right context, left context, and
+        current chunk, to compute the key and value tensors:
+          key & value = [memory, right_context, left_context, utterance],
+        with length KV = M + R + L + U.
+
+        The chunk-wise attention is:
+          chunk, right context (in query) ->
+            left context, chunk, right context, memory vectors (in key);
+          summary (in query) -> left context, chunk, right context (in key).
+
+        Relative positional encoding is applied on the part of attention
+        between chunk (in query) and chunk itself as well as its left context
+        (in key):
+          chunk (in query) -> left context, chunk (in key).
+
+        Args:
+          utterance (torch.Tensor):
+            Current chunk frames, with shape (U, B, D),
+            where U = chunk_length.
+          lengths (torch.Tensor):
+            With shape (B,) and i-th element representing the number of valid
+            frames for the i-th batch element in utterance.
+          right_context (torch.Tensor):
+            Right context frames, with shape (R, B, D),
+            where R = right_context_length.
+          summary (torch.Tensor):
+            Summary vector with shape (1, B, D), or an empty tensor.
+          memory (torch.Tensor):
+            Memory vectors, with shape (M, B, D), or an empty tensor.
+          left_context_key (torch.Tensor):
+            Cached attention key of left context from the preceding
+            computation, with shape (L, B, D),
+            where L <= left_context_length.
+          left_context_val (torch.Tensor):
+            Cached attention value of left context from the preceding
+            computation, with shape (L, B, D),
+            where L <= left_context_length.
+          pos_emb (torch.Tensor):
+            Position encoding embedding, with shape (PE, D),
+            where PE = L + 2 * U - 1.
+
+        Returns:
+          A tuple containing 4 tensors:
+            - output of right context and utterance, with shape (R + U, B, D).
+            - memory output, with shape (1, B, D) or (0, B, D).
+            - attention key of left context and utterance, which would be
+              cached for the next computation, with shape (L + U, B, D).
+            - attention value of left context and utterance, which would be
+              cached for the next computation, with shape (L + U, B, D).
+        """
+        U = utterance.size(0)
+        R = right_context.size(0)
+        L = left_context_key.size(0)
+        S = summary.size(0)
+        M = memory.size(0)
+
+        # query = [right context, utterance, summary]
+        Q = R + U + S
+        # key, value = [memory, right context, left context, utterance]
+        KV = M + R + L + U
+        attention_mask = torch.zeros(Q, KV).to(
+            dtype=torch.bool, device=utterance.device
+        )
+        # disallow attention between the summary vector and the memory bank
+        attention_mask[-1, :M] = True
+        (
+            output_right_context_utterance,
+            output_memory,
+            key,
+            value,
+        ) = self._forward_impl(
+            utterance,
+            lengths,
+            right_context,
+            summary,
+            memory,
+            attention_mask,
+            pos_emb,
+            left_context_key=left_context_key,
+            left_context_val=left_context_val,
+        )
+        return (
+            output_right_context_utterance,
+            output_memory,
+            key[M + R :],
+            value[M + R :],
+        )
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py
index 528931a54..03835f0d7 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py
@@ -13,5 +13,96 @@ def test_rel_positional_encoding():
     assert pos_emb.shape == (pos_len + neg_len - 1, D)
 
 
+def test_emformer_attention_forward():
+    from emformer import EmformerAttention
+
+    B, D = 2, 256
+    chunk_length = 4
+    right_context_length = 2
+    num_chunks = 3
+    U = num_chunks * chunk_length
+    R = num_chunks * right_context_length
+    attention = EmformerAttention(embed_dim=D, nhead=8)
+
+    for use_memory in [True, False]:
+        if use_memory:
+            S = num_chunks
+            M = S - 1
+        else:
+            S, M = 0, 0
+
+        Q, KV = R + U + S, M + R + U
+        utterance = torch.randn(U, B, D)
+        lengths = torch.randint(1, U + 1, (B,))
+        lengths[0] = U
+        right_context = torch.randn(R, B, D)
+        summary = torch.randn(S, B, D)
+        memory = torch.randn(M, B, D)
+        attention_mask = torch.rand(Q, KV) >= 0.5
+        PE = 2 * U - 1
+        pos_emb = torch.randn(PE, D)
+
+        output_right_context_utterance, output_memory = attention(
+            utterance,
+            lengths,
+            right_context,
+            summary,
+            memory,
+            attention_mask,
+            pos_emb,
+        )
+        assert output_right_context_utterance.shape == (R + U, B, D)
+        assert output_memory.shape == (M, B, D)
+
+
+def test_emformer_attention_infer():
+    from emformer import EmformerAttention
+
+    B, D = 2, 256
+    U = 4
+    R = 2
+    L = 3
+    attention = EmformerAttention(embed_dim=D, nhead=8)
+
+    for use_memory in [True, False]:
+        if use_memory:
+            S, M = 1, 3
+        else:
+            S, M = 0, 0
+
+        utterance = torch.randn(U, B, D)
+        lengths = torch.randint(1, U + 1, (B,))
+        lengths[0] = U
+        right_context = torch.randn(R, B, D)
+        summary = torch.randn(S, B, D)
+        memory = torch.randn(M, B, D)
+        left_context_key = torch.randn(L, B, D)
+        left_context_val = torch.randn(L, B, D)
+        PE = L + 2 * U - 1
+        pos_emb = torch.randn(PE, D)
+
+        (
+            output_right_context_utterance,
+            output_memory,
+            next_key,
+            next_val,
+        ) = attention.infer(
+            utterance,
+            lengths,
+            right_context,
+            summary,
+            memory,
+            left_context_key,
+            left_context_val,
+            pos_emb,
+        )
+        assert output_right_context_utterance.shape == (R + U, B, D)
+        assert output_memory.shape == (S, B, D)
+        assert next_key.shape == (L + U, B, D)
+        assert next_val.shape == (L + U, B, D)
+
+
 if __name__ == "__main__":
     test_rel_positional_encoding()
+    test_emformer_attention_forward()
+    test_emformer_attention_infer()
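Note (not part of the patch): the `as_strided` call in `EmformerAttention._rel_shift` is the least obvious piece of this diff. The standalone sketch below, with illustrative helper names `rel_shift` and `rel_shift_reference` and arbitrary small sizes, reproduces the same stride trick and checks it against a naive indexing version; it assumes only PyTorch.

import torch


def rel_shift(x: torch.Tensor) -> torch.Tensor:
    # same stride trick as EmformerAttention._rel_shift:
    # (B, nhead, U, PE) -> (B, nhead, U, PE - U + 1), where output row u
    # is the slice x[..., u, U - 1 - u : U - 1 - u + out_len].
    B, nhead, U, PE = x.size()
    out_len = PE - (U - 1)
    return x.as_strided(
        size=(B, nhead, U, out_len),
        stride=(x.stride(0), x.stride(1), x.stride(2) - x.stride(3), x.stride(3)),
        storage_offset=x.stride(3) * (U - 1),
    )


def rel_shift_reference(x: torch.Tensor) -> torch.Tensor:
    # naive version used only to verify the stride trick above.
    B, nhead, U, PE = x.size()
    out_len = PE - (U - 1)
    out = torch.empty(B, nhead, U, out_len, dtype=x.dtype)
    for u in range(U):
        start = U - 1 - u
        out[:, :, u, :] = x[:, :, u, start : start + out_len]
    return out


if __name__ == "__main__":
    U, L = 4, 3
    # PE = 2 * U - 1 mirrors training mode; PE = L + 2 * U - 1 mirrors inference.
    for PE in [2 * U - 1, L + 2 * U - 1]:
        x = torch.randn(2, 8, U, PE)
        assert torch.equal(rel_shift(x), rel_shift_reference(x))
    print("rel_shift matches the naive reference")

Because `as_strided` only reinterprets strides, the shift is a view with no copy; the subsequent `.contiguous().view(...)` in `_forward_impl` is what materializes it.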