# Copyright      2022  Xiaomi Corp.        (author: Quandong Wang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.init import xavier_normal_

from scaling import ScaledLinear


class MultiheadAttention(nn.Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    Args:
        embed_dim: Total dimension of the model.
        num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be
            split across ``num_heads`` (i.e. each head will have dimension
            ``embed_dim // num_heads``).
        dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0``
            (no dropout).
        bias: If specified, adds bias to input / output projection layers.
            Default: ``True``.
        add_bias_kv: If specified, adds bias to the key and value sequences at dim=0.
            Default: ``False``.
        add_zero_attn: If specified, adds a new batch of zeros to the key and value
            sequences at dim=1. Default: ``False``.
        kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
        vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
    Examples::

        >>> multihead_attn = MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """
    __constants__ = ["batch_first"]
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = (
            self.kdim == embed_dim and self.vdim == embed_dim
        )

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim is False:
            # Despite the ``_weight`` suffix (kept for compatibility with the
            # corresponding nn.MultiheadAttention attribute names), these are
            # ScaledLinear modules, not raw weight tensors.
            self.q_proj_weight = ScaledLinear(embed_dim, embed_dim, bias=bias)
            self.k_proj_weight = ScaledLinear(self.kdim, embed_dim, bias=bias)
            self.v_proj_weight = ScaledLinear(self.vdim, embed_dim, bias=bias)
            self.register_parameter("in_proj_weight", None)
        else:
            self.in_proj_weight = ScaledLinear(
                embed_dim, 3 * embed_dim, bias=bias
            )
            self.register_parameter("q_proj_weight", None)
            self.register_parameter("k_proj_weight", None)
            self.register_parameter("v_proj_weight", None)

        if not bias:
            self.register_parameter("in_proj_bias", None)
        self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = nn.Parameter(
                torch.empty((1, 1, embed_dim), **factory_kwargs)
            )
            self.bias_v = nn.Parameter(
                torch.empty((1, 1, embed_dim), **factory_kwargs)
            )
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if "_qkv_same_embed_dim" not in state:
            state["_qkv_same_embed_dim"] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Args:
            query: Query embeddings of shape :math:`(L, N, E_q)` when ``batch_first=False``
                or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the
                target sequence length, :math:`N` is the batch size, and :math:`E_q` is
                the query embedding dimension ``embed_dim``. Queries are compared against
                key-value pairs to produce the output. See "Attention Is All You Need"
                for more details.
            key: Key embeddings of shape :math:`(S, N, E_k)` when ``batch_first=False`` or
                :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the
                source sequence length, :math:`N` is the batch size, and :math:`E_k` is
                the key embedding dimension ``kdim``. See "Attention Is All You Need"
                for more details.
            value: Value embeddings of shape :math:`(S, N, E_v)` when ``batch_first=False``
                or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the
                source sequence length, :math:`N` is the batch size, and :math:`E_v` is
                the value embedding dimension ``vdim``. See "Attention Is All You Need"
                for more details.
            key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which
                elements within ``key`` to ignore for the purpose of attention (i.e. treat
                as "padding"). Binary and byte masks are supported. For a binary mask, a
                ``True`` value indicates that the corresponding ``key`` value will be
                ignored for the purpose of attention. For a byte mask, a non-zero value
                indicates that the corresponding ``key`` value will be ignored.
            need_weights: If specified, returns ``attn_output_weights`` in addition to
                ``attn_outputs``. Default: ``True``.
            attn_mask: If specified, a 2D or 3D mask preventing attention to certain
                positions. Must be of shape :math:`(L, S)` or
                :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
                :math:`L` is the target sequence length, and :math:`S` is the source
                sequence length. A 2D mask will be broadcasted across the batch while a 3D
                mask allows for a different mask for each entry in the batch. Binary, byte,
                and float masks are supported. For a binary mask, a ``True`` value indicates
                that the corresponding position is not allowed to attend. For a byte mask,
                a non-zero value indicates that the corresponding position is not allowed
                to attend. For a float mask, the mask values will be added to the attention
                weight.

        Outputs:
            - **attn_output** - Attention outputs of shape :math:`(L, N, E)` when
              ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, where
              :math:`L` is the target sequence length, :math:`N` is the batch size, and
              :math:`E` is the embedding dimension ``embed_dim``.
            - **attn_output_weights** - Attention output weights of shape :math:`(N, L, S)`,
              where :math:`N` is the batch size, :math:`L` is the target sequence length,
              and :math:`S` is the source sequence length. Only returned when
              ``need_weights=True``.
        """
        if self.batch_first:
            query, key, value = [
                x.transpose(1, 0) for x in (query, key, value)
            ]

        if not self._qkv_same_embed_dim:
            q_proj_weight = (
                self.q_proj_weight.get_weight()
                if self.q_proj_weight is not None
                else None
            )
            k_proj_weight = (
                self.k_proj_weight.get_weight()
                if self.k_proj_weight is not None
                else None
            )
            v_proj_weight = (
                self.v_proj_weight.get_weight()
                if self.v_proj_weight is not None
                else None
            )
            # In this branch ``self.in_proj_weight`` was registered as None, so
            # it cannot supply the projection bias.  Build the single
            # ``(3 * embed_dim,)`` bias tensor expected by
            # ``multi_head_attention_forward`` from the per-projection biases,
            # or pass None when bias is disabled.
            q_bias = self.q_proj_weight.get_bias()
            if q_bias is not None:
                in_proj_bias = torch.cat(
                    [
                        q_bias,
                        self.k_proj_weight.get_bias(),
                        self.v_proj_weight.get_bias(),
                    ]
                )
            else:
                in_proj_bias = None
            (
                attn_output,
                attn_output_weights,
            ) = nn.functional.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                None,  # in_proj_weight is unused when use_separate_proj_weight=True
                in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.get_weight(),
                self.out_proj.get_bias(),
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=q_proj_weight,
                k_proj_weight=k_proj_weight,
                v_proj_weight=v_proj_weight,
            )
        else:
            (
                attn_output,
                attn_output_weights,
            ) = nn.functional.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight.get_weight(),
                self.in_proj_weight.get_bias(),
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.get_weight(),
                self.out_proj.get_bias(),
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
            )
        if self.batch_first:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights
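

# What follows is a minimal, illustrative smoke test, not part of the original
# recipe: the helper names ``_test_multihead_attention`` and
# ``_test_attention_masks`` are made up here, and the checks rely only on this
# file's MultiheadAttention plus the ScaledLinear import above.


def _test_multihead_attention():
    """Sanity-check output shapes for the default (L, N, E) layout and for
    ``batch_first=True``, mirroring the shapes documented in the class and
    forward() docstrings."""
    embed_dim, num_heads = 256, 4
    L, S, N = 10, 12, 3  # target length, source length, batch size

    attn = MultiheadAttention(embed_dim, num_heads)
    query = torch.randn(L, N, embed_dim)
    key = torch.randn(S, N, embed_dim)
    value = torch.randn(S, N, embed_dim)

    attn_output, attn_weights = attn(query, key, value)
    assert attn_output.shape == (L, N, embed_dim)
    assert attn_weights.shape == (N, L, S)

    # The same computation with (N, L, E) inputs and outputs.
    attn_bf = MultiheadAttention(embed_dim, num_heads, batch_first=True)
    attn_output, attn_weights = attn_bf(
        query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1)
    )
    assert attn_output.shape == (N, L, embed_dim)
    assert attn_weights.shape == (N, L, S)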
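

def _test_attention_masks():
    """Exercise ``attn_mask`` and ``key_padding_mask`` as documented in
    forward(): with boolean masks, positions marked ``True`` should receive
    (numerically) zero weight in the head-averaged attention weights returned
    by ``multi_head_attention_forward``."""
    embed_dim, num_heads = 64, 4
    L = S = 6
    N = 2

    attn = MultiheadAttention(embed_dim, num_heads)
    attn.eval()  # eval mode; dropout is 0.0 by default anyway

    x = torch.randn(L, N, embed_dim)

    # Boolean causal mask of shape (L, S): True means "not allowed to attend".
    attn_mask = torch.triu(torch.ones(L, S, dtype=torch.bool), diagonal=1)
    # Mark the last two source positions of every sequence as padding.
    key_padding_mask = torch.zeros(N, S, dtype=torch.bool)
    key_padding_mask[:, -2:] = True

    with torch.no_grad():
        _, attn_weights = attn(
            x, x, x, key_padding_mask=key_padding_mask, attn_mask=attn_mask
        )

    # Padded keys get no attention weight for any query position.
    assert attn_weights[:, :, -2:].abs().max() < 1e-8
    # The first query position may only attend to source position 0.
    assert attn_weights[:, 0, 1:].abs().max() < 1e-8


if __name__ == "__main__":
    _test_multihead_attention()
    _test_attention_masks()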