diff --git a/egs/librispeech/ASR/conv_emformer_transducer/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer/emformer.py index 14e106460..bde228af7 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer/emformer.py @@ -80,13 +80,22 @@ class EmformerAttention(nn.Module): self.nhead = nhead self.tanh_on_mem = tanh_on_mem self.negative_inf = negative_inf + self.head_dim = embed_dim // nhead - self.scaling = (self.embed_dim // self.nhead) ** -0.5 + self.scaling = self.head_dim ** -0.5 self.emb_to_key_value = nn.Linear(embed_dim, 2 * embed_dim, bias=True) self.emb_to_query = nn.Linear(embed_dim, embed_dim, bias=True) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + # linear transformation for positional encoding. + self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) + + # these two learnable bias are used in matrix c and matrix d + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 # noqa + self.pos_bias_u = nn.Parameter(torch.Tensor(nhead, self.head_dim)) + self.pos_bias_v = nn.Parameter(torch.Tensor(nhead, self.head_dim)) + self._reset_parameters() def _reset_parameters(self) -> None: @@ -99,6 +108,11 @@ class EmformerAttention(nn.Module): nn.init.xavier_uniform_(self.out_proj.weight) nn.init.constant_(self.out_proj.bias, 0.0) + nn.init.xavier_uniform_(self.linear_pos.weight) + + nn.init.xavier_uniform_(self.pos_bias_u) + nn.init.xavier_uniform_(self.pos_bias_v) + def _gen_attention_probs( self, attention_weights: torch.Tensor, @@ -152,6 +166,32 @@ class EmformerAttention(nn.Module): return attention_probs + def _rel_shift(self, x: torch.Tensor) -> torch.Tensor: + """Compute relative positional encoding. + + Args: + x: Input tensor, of shape (B, nhead, U, PE). + U is the length of query vector. + For non-infer mode, PE = 2 * U - 1; + for infer mode, PE = L + 2 * U - 1. + + Returns: + A tensor of shape (B, nhead, U, out_len). + For non-infer mode, out_len = U; + for infer mode, out_len = L + U. + """ + B, nhead, U, PE = x.size() + B_stride = x.stride(0) + nhead_stride = x.stride(1) + U_stride = x.stride(2) + PE_stride = x.stride(3) + out_len = PE - (U - 1) + return x.as_strided( + size=(B, nhead, U, out_len), + stride=(B_stride, nhead_stride, U_stride - PE_stride, PE_stride), + storage_offset=PE_stride * (U - 1), + ) + def _forward_impl( self, utterance: torch.Tensor, @@ -160,6 +200,7 @@ class EmformerAttention(nn.Module): summary: torch.Tensor, memory: torch.Tensor, attention_mask: torch.Tensor, + pos_emb: torch.Tensor, left_context_key: Optional[torch.Tensor] = None, left_context_val: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: @@ -194,6 +235,10 @@ class EmformerAttention(nn.Module): Memory elements, with shape (M, B, D). attention_mask (torch.Tensor): Attention mask for underlying attention, with shape (Q, KV). + pos_emb (torch.Tensor): + Position encoding embedding, with shape (PE, D). + For training mode, PE = 2*U-1; + For infer mode, PE = L+2*U-1. left_context_key (torch,Tensor, optional): Cached attention key of left context from preceding computation, with shape (L, B, D). @@ -208,7 +253,9 @@ class EmformerAttention(nn.Module): - attention key, with shape (KV, B, D). - attention value, with shape (KV, B, D). """ - B = utterance.size(1) + U, B, _ = utterance.size() + R = right_context.size(0) + M = memory.size(0) # Compute query with [right context, utterance, summary]. query = self.emb_to_query( @@ -222,41 +269,71 @@ class EmformerAttention(nn.Module): if left_context_key is not None and left_context_val is not None: # This is for inference mode. Now compute key and value with # [mems, right context, left context, uttrance] - M = memory.size(0) - R = right_context.size(0) - right_context_end_idx = M + R key = torch.cat( - [ - key[:right_context_end_idx], - left_context_key, - key[right_context_end_idx:], - ] + [key[: M + R], left_context_key, key[M + R :]] # noqa ) value = torch.cat( - [ - value[:right_context_end_idx], - left_context_val, - value[right_context_end_idx:], - ] + [value[: M + R], left_context_val, value[M + R :]] # noqa ) + Q = query.size(0) + KV = key.size(0) - # Compute attention weights from query, key, and value. - reshaped_query, reshaped_key, reshaped_value = [ + reshaped_key, reshaped_value = [ tensor.contiguous() - .view(-1, B * self.nhead, self.embed_dim // self.nhead) + .view(KV, B * self.nhead, self.head_dim) .transpose(0, 1) - for tensor in [query, key, value] - ] - attention_weights = torch.bmm( - reshaped_query * self.scaling, reshaped_key.transpose(1, 2) + for tensor in [key, value] + ] # (B * nhead, KV, head_dim) + reshaped_query = query.contiguous().view( + Q, B, self.nhead, self.head_dim ) + # compute attention matrix ac + query_with_bais_u = ( + (reshaped_query + self.pos_bias_u) + .view(Q, B * self.nhead, self.head_dim) + .transpose(0, 1) + ) + matrix_ac = torch.bmm( + query_with_bais_u, reshaped_key.transpose(1, 2) + ) # (B * nhead, Q, KV) + + # compute attention matrix bd + utterance_with_bais_v = ( + reshaped_query[R : R + U] + self.pos_bias_v + ).permute(1, 2, 0, 3) + # (B, nhead, U, head_dim) + PE = pos_emb.size(0) + if left_context_key is not None and left_context_val is not None: + L = left_context_key.size(0) + assert PE == L + 2 * U - 1 + else: + assert PE == 2 * U - 1 + pos_emb = ( + self.linear_pos(pos_emb) + .view(PE, self.nhead, self.head_dim) + .transpose(0, 1) + .unsqueeze(0) + ) # (1, nhead, PE, head_dim) + matrix_bd_utterance = torch.matmul( + utterance_with_bais_v, pos_emb.transpose(-2, -1) + ) # (B, nhead, U, PE) + # rel-shift + matrix_bd_utterance = self._rel_shift( + matrix_bd_utterance + ) # (B, nhead, U, U or L + U) + matrix_bd_utterance = matrix_bd_utterance.contiguous().view( + B * self.nhead, U, -1 + ) + matrix_bd = torch.zeros_like(matrix_ac) + matrix_bd[:, R : R + U, M + R :] = matrix_bd_utterance + + attention_weights = (matrix_ac + matrix_bd) * self.scaling + # Compute padding mask if B == 1: padding_mask = None else: - KV = key.size(0) - U = utterance.size(0) padding_mask = make_pad_mask(KV - U + lengths) # Compute attention probabilities. @@ -266,12 +343,7 @@ class EmformerAttention(nn.Module): # Compute attention. attention = torch.bmm(attention_probs, reshaped_value) - Q = query.size(0) - assert attention.shape == ( - B * self.nhead, - Q, - self.embed_dim // self.nhead, - ) + assert attention.shape == (B * self.nhead, Q, self.head_dim) attention = ( attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) ) @@ -279,10 +351,8 @@ class EmformerAttention(nn.Module): # Apply output projection. outputs = self.out_proj(attention) - S = summary.size(0) - summary_start_idx = Q - S - output_right_context_utterance = outputs[:summary_start_idx] - output_memory = outputs[summary_start_idx:] + output_right_context_utterance = outputs[: R + U] + output_memory = outputs[R + U :] if self.tanh_on_mem: output_memory = torch.tanh(output_memory) else: @@ -298,6 +368,7 @@ class EmformerAttention(nn.Module): summary: torch.Tensor, memory: torch.Tensor, attention_mask: torch.Tensor, + pos_emb: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: # TODO: Modify docs. """Forward pass for training. @@ -324,6 +395,9 @@ class EmformerAttention(nn.Module): attention_mask (torch.Tensor): Attention mask for underlying chunk-wise attention, with shape (Q, KV), where Q = R + U + S, KV = M + R + U. + pos_emb (torch.Tensor): + Position encoding embedding, with shape (PE, D). + For training mode, P = 2*U-1. Returns: A tuple containing 2 tensors: @@ -336,7 +410,13 @@ class EmformerAttention(nn.Module): _, _, ) = self._forward_impl( - utterance, lengths, right_context, summary, memory, attention_mask + utterance, + lengths, + right_context, + summary, + memory, + attention_mask, + pos_emb, ) return output_right_context_utterance, output_memory[:-1] @@ -350,6 +430,7 @@ class EmformerAttention(nn.Module): memory: torch.Tensor, left_context_key: torch.Tensor, left_context_val: torch.Tensor, + pos_emb: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Forward pass for inference. @@ -379,6 +460,9 @@ class EmformerAttention(nn.Module): left_context_val (torch.Tensor): Cached attention value of left context from preceding computation, with shape (L, B, D). + pos_emb (torch.Tensor): + Position encoding embedding, with shape (PE, D). + For infer mode, PE = L+2*U-1. Returns: A tuple containing 4 tensors: @@ -394,9 +478,9 @@ class EmformerAttention(nn.Module): # key, value: [memory, right context, left context, uttrance] KV = ( memory.size(0) - + right_context.size(0) - + left_context_key.size(0) - + utterance.size(0) + + right_context.size(0) # noqa + + left_context_key.size(0) # noqa + + utterance.size(0) # noqa ) attention_mask = torch.zeros(Q, KV).to( dtype=torch.bool, device=utterance.device @@ -415,6 +499,7 @@ class EmformerAttention(nn.Module): summary, memory, attention_mask, + pos_emb, left_context_key=left_context_key, left_context_val=left_context_val, ) @@ -643,6 +728,7 @@ class EmformerLayer(nn.Module): right_context_end_idx: int, lengths: torch.Tensor, memory: torch.Tensor, + pos_emb: torch.Tensor, attention_mask: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: """Apply attention module in non-infer mode.""" @@ -671,6 +757,7 @@ class EmformerLayer(nn.Module): summary=summary, memory=memory, attention_mask=attention_mask, + pos_emb=pos_emb, ) right_context_utterance = residual + self.dropout( output_right_context_utterance @@ -684,6 +771,7 @@ class EmformerLayer(nn.Module): right_context_end_idx: int, lengths: torch.Tensor, memory: torch.Tensor, + pos_emb: torch.Tensor, state: Optional[List[torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: """Apply attention in infer mode. @@ -717,6 +805,14 @@ class EmformerLayer(nn.Module): summary = torch.empty(0).to( dtype=utterance.dtype, device=utterance.device ) + # pos_emb is of shape [PE, D], PE = L + 2 * U - 1, + # the relative distance j - i of key(j) and query(i) is in range of [-(L + U - 1), (U - 1)] # noqa + L = left_context_key.size(0) # L <= left_context_length + U = utterance.size(0) + PE = L + 2 * U - 1 + tot_PE = self.left_context_length + 2 * U - 1 + assert pos_emb.size(0) == tot_PE + pos_emb = pos_emb[tot_PE - PE :] ( output_right_context_utterance, output_memory, @@ -730,6 +826,7 @@ class EmformerLayer(nn.Module): memory=pre_memory, left_context_key=left_context_key, left_context_val=left_context_val, + pos_emb=pos_emb, ) right_context_utterance = residual + self.dropout( output_right_context_utterance @@ -746,6 +843,7 @@ class EmformerLayer(nn.Module): right_context: torch.Tensor, memory: torch.Tensor, attention_mask: torch.Tensor, + pos_emb: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: r"""Forward pass for training. 1) Apply layer normalization on input utterance and right context @@ -774,6 +872,9 @@ class EmformerLayer(nn.Module): attention_mask (torch.Tensor): Attention mask for underlying attention module, with shape (Q, KV), where Q = R + U + S, KV = M + R + U. + pos_emb (torch.Tensor): + Position encoding embedding, with shape (PE, D). + For training mode, P = 2*U-1. Returns: A tuple containing 3 tensors: @@ -797,6 +898,7 @@ class EmformerLayer(nn.Module): lengths, memory, attention_mask, + pos_emb, ) right_context_utterance = self._apply_conv_module_forward( @@ -820,6 +922,7 @@ class EmformerLayer(nn.Module): lengths: torch.Tensor, right_context: torch.Tensor, memory: torch.Tensor, + pos_emb: torch.Tensor, state: Optional[List[torch.Tensor]] = None, conv_cache: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]: @@ -851,6 +954,9 @@ class EmformerLayer(nn.Module): state (List[torch.Tensor], optional): List of tensors representing layer internal state generated in preceding computation. (default=None) + pos_emb (torch.Tensor): + Position encoding embedding, with shape (PE, D). + For infer mode, PE = L+2*U-1. conv_cache (torch.Tensor, optional): Cache tensor of left context for causal convolution. @@ -878,6 +984,7 @@ class EmformerLayer(nn.Module): right_context_end_idx, lengths, memory, + pos_emb, state, ) @@ -1124,7 +1231,10 @@ class EmformerEncoder(nn.Module): return attention_mask def forward( - self, x: torch.Tensor, lengths: torch.Tensor + self, + x: torch.Tensor, + lengths: torch.Tensor, + pos_emb: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass for training and non-streaming inference. @@ -1140,6 +1250,9 @@ class EmformerEncoder(nn.Module): With shape (B,) and i-th element representing number of valid utterance frames for i-th batch element in x, which contains the right_context at the end. + pos_emb (torch.Tensor): + Position encoding embedding, with shape (PE, D). + For training mode, P = 2*U-1. Returns: A tuple of 2 tensors: @@ -1161,7 +1274,12 @@ class EmformerEncoder(nn.Module): output = utterance for layer in self.emformer_layers: output, right_context, memory = layer( - output, output_lengths, right_context, memory, attention_mask + output, + output_lengths, + right_context, + memory, + attention_mask, + pos_emb, ) return output, output_lengths @@ -1171,6 +1289,7 @@ class EmformerEncoder(nn.Module): self, x: torch.Tensor, lengths: torch.Tensor, + pos_emb: torch.Tensor, states: Optional[List[List[torch.Tensor]]] = None, conv_caches: Optional[List[torch.Tensor]] = None, ) -> Tuple[ @@ -1190,6 +1309,9 @@ class EmformerEncoder(nn.Module): With shape (B,) and i-th element representing number of valid utterance frames for i-th batch element in x, which contains the right_context at the end. + pos_emb (torch.Tensor): + Position encoding embedding, with shape (PE, D). + For infer mode, PE = L+2*U-1. states (List[List[torch.Tensor]], optional): Cached states from proceeding chunk's computation, where each element (List[torch.Tensor]) corresponds to each emformer layer. @@ -1234,6 +1356,7 @@ class EmformerEncoder(nn.Module): output_lengths, right_context, memory, + pos_emb, None if states is None else states[layer_idx], None if conv_caches is None else conv_caches[layer_idx], ) @@ -1291,6 +1414,8 @@ class Emformer(EncoderInterface): else: self.encoder_embed = Conv2dSubsampling(num_features, d_model) + self.encoder_pos = RelPositionalEncoding(d_model, dropout) + self.encoder = EmformerEncoder( chunk_length // 4, d_model, @@ -1338,6 +1463,10 @@ class Emformer(EncoderInterface): right_context at the end. """ x = self.encoder_embed(x) + + # TODO: The length computation in the encoder class should be moved here. # noqa + U = x.size(1) - self.right_context_length // 4 + x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) # Caution: We assume the subsampling factor is 4! @@ -1346,7 +1475,7 @@ class Emformer(EncoderInterface): x_lens = ((x_lens - 1) // 2 - 1) // 2 assert x.size(0) == x_lens.max().item() - output, output_lengths = self.encoder(x, x_lens) # (T, N, C) + output, output_lengths = self.encoder(x, x_lens, pos_emb) # (T, N, C) logits = self.encoder_output_layer(output) logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) @@ -1392,6 +1521,12 @@ class Emformer(EncoderInterface): - updated convolution caches from current chunk. """ x = self.encoder_embed(x) + + # TODO: The length computation in the encoder class should be moved here. # noqa + pos_len = self.chunk_length // 4 + self.left_context_length // 4 + neg_len = self.chunk_length // 4 + x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) # Caution: We assume the subsampling factor is 4! @@ -1405,7 +1540,7 @@ class Emformer(EncoderInterface): output_lengths, output_states, output_conv_caches, - ) = self.encoder.infer(x, x_lens, states, conv_caches) + ) = self.encoder.infer(x, x_lens, pos_emb, states, conv_caches) logits = self.encoder_output_layer(output) logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) @@ -1533,6 +1668,111 @@ class ConvolutionModule(nn.Module): return x.permute(2, 0, 1), new_cache +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module. + + See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" # noqa + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py # noqa + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. + + """ + + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: + """Construct an PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.pos_len = max_len + self.neg_len = max_len + self.gen_pe() + + def gen_pe(self) -> None: + """Generate the positional encodings.""" + # Suppose `i` means to the position of query vecotr and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i torch.Tensor: + """Get positional encoding given positive length and negative length.""" + if self.pe_positive.dtype != dtype or str( + self.pe_positive.device + ) != str(device): + self.pe_positive = self.pe_positive.to(dtype=dtype, device=device) + if self.pe_negative.dtype != dtype or str( + self.pe_negative.device + ) != str(device): + self.pe_negative = self.pe_negative.to(dtype=dtype, device=device) + pe = torch.cat( + [ + self.pe_positive[self.pos_len - pos_len :], + self.pe_negative[1:neg_len], + ], + dim=0, + ) + return pe + + def forward( + self, + x: torch.Tensor, + pos_len: int, + neg_len: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). + + """ + x = x * self.xscale + if pos_len > self.pos_len or neg_len > self.neg_len: + self.pos_len = pos_len + self.neg_len = neg_len + self.gen_pe() + pos_emb = self.get_pe(pos_len, neg_len, x.device, x.dtype) + return self.dropout(x), self.dropout(pos_emb) + + class Swish(torch.nn.Module): """Construct an Swish object."""