mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-14 04:22:21 +00:00
1483 lines
54 KiB
Python
1483 lines
54 KiB
Python
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
|
|
# Apache 2.0
|
|
|
|
import math
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch import Tensor
|
|
from torch.nn.utils.rnn import pad_sequence
|
|
|
|
# Note: TorchScript requires Dict/List/etc. to be fully typed.
|
|
Supervisions = Dict[str, torch.Tensor]
|
|
|
|
|
|
class MaskedLmConformer(nn.Module):
    """Masked language model: a conformer encoder plus an optional decoder.

    A single embedding table (`self.embed`) is shared by the encoder input
    (the masked symbols) and the decoder input (the un-masked symbols).
    """

    def __init__(
        self,
        num_classes: int,
        d_model: int = 256,
        nhead: int = 4,
        dim_feedforward: int = 2048,
        num_encoder_layers: int = 12,
        num_decoder_layers: int = 6,
        dropout: float = 0.1,
        cnn_module_kernel: int = 31,
    ) -> None:
        """
        Args:
          num_classes:
            The input and output dimension of the model (inputs and outputs are
            both discrete)
          d_model:
            Attention dimension.
          nhead:
            Number of heads in multi-head attention.
            Must satisfy d_model % nhead == 0.
          dim_feedforward:
            The output dimension of the feedforward layers in encoder/decoder.
          num_encoder_layers:
            Number of encoder layers.
          num_decoder_layers:
            Number of decoder layers.  If 0, no decoder is created and
            `decoder_nll` must not be called.
          dropout:
            Dropout in encoder/decoder.
          cnn_module_kernel:
            Kernel size of the convolution module in each encoder layer.
        """
        super(MaskedLmConformer, self).__init__()

        self.num_classes = num_classes
        # This must be assigned before `self.embed` is built, because the
        # embedding table is sized by it.
        self.decoder_num_class = self.num_classes

        # self.embed is the embedding used for both the encoder and decoder.
        # Its weights are initialized with stddev 1/sqrt(d_model) and the
        # lookups are multiplied by embed_scale == sqrt(d_model) when used.
        self.embed_scale = d_model ** 0.5
        self.embed = nn.Embedding(
            num_embeddings=self.decoder_num_class,
            embedding_dim=d_model,
            _weight=torch.randn(self.decoder_num_class, d_model)
            * (1 / self.embed_scale),
        )

        self.encoder_pos = RelPositionalEncoding(d_model, dropout)

        encoder_layer = MaskedLmConformerEncoderLayer(
            d_model,
            nhead,
            dim_feedforward,
            dropout,
            cnn_module_kernel,
        )
        self.encoder = MaskedLmConformerEncoder(
            encoder_layer, num_encoder_layers, norm=nn.LayerNorm(d_model)
        )

        if num_decoder_layers > 0:
            decoder_layer = TransformerDecoderLayerRelPos(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
            )

            # Projects the embedding of `src`, to be added to `memory`
            self.src_linear = torch.nn.Linear(d_model, d_model)

            decoder_norm = nn.LayerNorm(d_model)
            self.decoder = TransformerDecoderRelPos(
                decoder_layer=decoder_layer,
                num_layers=num_decoder_layers,
                norm=decoder_norm,
            )

            self.decoder_output_layer = torch.nn.Linear(
                d_model, self.decoder_num_class
            )

    def forward(
        self,
        masked_src_symbols: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the encoder on the masked input symbols.

        Args:
          masked_src_symbols:
            The input symbols to be embedded (will actually have query positions
            masked), as a Tensor of shape (batch_size, seq_len) and dtype=torch.int64.
            I.e. shape (N, T)
          key_padding_mask:
            Either None, or a Tensor of shape (batch_size, seq_len) i.e. (N, T),
            and dtype=torch.bool which has True in positions to be masked in attention
            layers and convolutions because they represent padding at the ends of
            sequences.

        Returns:
          Returns (encoded, pos_emb), where:
            `encoded` is a Tensor containing the encoded data; it is of shape
              (T, N, C) where C is the embedding_dim (the input is permuted to
              time-major before the encoder and is not permuted back).
            `pos_emb` is a Tensor containing the relative positional encoding, of
              shape (1, 2*T-1, C)
        """
        x = self.embed(masked_src_symbols) * self.embed_scale  # (N, T, C)
        x, pos_emb = self.encoder_pos(x)  # pos_emb: (1, 2*T-1, C)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask)  # (T, N, C)

        return x, pos_emb

    def decoder_nll(
        self,
        memory: torch.Tensor,
        pos_emb: torch.Tensor,
        src_symbols: torch.Tensor,
        tgt_symbols: torch.Tensor,
        key_padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Compute per-position negative log-likelihoods with the decoder.

        Args:
          memory:
            The output of the encoder, with shape (T, N, C)
          pos_emb:
            Relative positional embedding, of shape (1, 2*T-1, C), as
            returned from the encoder
          src_symbols:
            The un-masked src symbols, a LongTensor of shape (N, T).
            Can be used to predict the target
            only in a left-to-right manner (otherwise it's cheating).
          tgt_symbols:
            Target symbols, a LongTensor of shape (N, T).
            The same as src_symbols, but shifted by one (and also,
            without symbol randomization, see randomize_proportion
            in dataloader)
          key_padding_mask:
            A BoolTensor of shape (N, T), with True for positions
            that correspond to padding at the end of source and
            memory sequences.  The same mask is used for self-attention
            and cross-attention, since the padding is the same.

        Returns:
          Returns a tensor of shape (N, T), containing the negative
          log-probabilities for the target symbols at each position
          in the target sequence.
        """
        (T, N, C) = memory.shape

        # Causal mask, so position t can only attend to positions <= t.
        attn_mask = generate_square_subsequent_mask(T, memory.device)

        src = self.embed(src_symbols) * self.embed_scale  # (N, T) -> (N, T, C)
        src = src.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        src = memory + self.src_linear(src)  # (T, N, C)

        # This is a little confusing: the decoder's input sequence is `src`,
        # the symbol sequence without masking but with padding and
        # randomization.  `tgt_symbols` is like `src_symbols` but shifted by one.
        # The call follows TransformerDecoderRelPos.forward's signature:
        # (x, pos_emb, memory, attn_mask, key_padding_mask).
        pred = self.decoder(
            src,
            pos_emb,
            memory,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
        )  # (T, N, C)

        pred = pred.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
        pred = self.decoder_output_layer(pred)  # (N, T, num_classes)

        # nll: negative log-likelihood.  reshape() rather than view() so that
        # non-contiguous inputs are accepted.
        nll = torch.nn.functional.cross_entropy(
            pred.reshape(-1, self.decoder_num_class),
            tgt_symbols.reshape(-1),
            reduction="none",
        )
        nll = nll.view(N, T)
        return nll
|
|
|
|
|
|
|
|
|
|
class TransformerDecoderRelPos(nn.Module):
    r"""TransformerDecoderRelPos is a stack of N decoder layers.
    This is modified from nn.TransformerDecoder to support relative positional
    encoding.

    Args:
        decoder_layer: an instance of the TransformerDecoderLayerRelPos() class (required).
        num_layers: the number of sub-decoder-layers in the decoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> decoder_layer = TransformerDecoderLayerRelPos(d_model=512, nhead=8)
        >>> transformer_decoder = TransformerDecoderRelPos(decoder_layer, num_layers=6)
        >>> memory = torch.rand(20, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> pos_emb = torch.rand(1, 2 * 20 - 1, 512)
        >>> out = transformer_decoder(tgt, pos_emb, memory)
    """
    __constants__ = ['norm']

    def __init__(self, decoder_layer, num_layers, norm=None):
        super(TransformerDecoderRelPos, self).__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, x: Tensor,
                pos_emb: Tensor,
                memory: Tensor,
                attn_mask: Optional[Tensor] = None,
                key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer in turn.

        Args:
            x: the input embedding sequence to the decoder (required): shape = (T, N, C).
                Will be an embedding of `src_symbols` in practice
            pos_emb:
                A torch.Tensor with dtype=torch.float and shape (1, 2*T-1, C) with c==num_channels,
                representing the relative positional encoding.
            memory: the sequence from the last layer of the encoder (required):
                shape = (T, N, C)
            attn_mask: the mask for the `x` sequence's attention to itself,
                of shape (T, T); in practice, will ensure that no
                position can attend to later positions.  A torch.Tensor with dtype=torch.float
                or dtype=torch.bool.
            key_padding_mask: the key-padding mask for both the memory and x sequences,
                a torch.Tensor with dtype=bool and shape (N, T): true for masked
                positions after the ends of sequences.

        Returns:
            A tensor of the same shape as `x`, i.e. (T, N, C): the output of the
            last layer, passed through `self.norm` if that is not None.
        """
        for mod in self.layers:
            # The layer's keyword for the causal mask is `x_mask`.
            x = mod(x, pos_emb, memory, x_mask=attn_mask,
                    key_padding_mask=key_padding_mask)

        if self.norm is not None:
            x = self.norm(x)

        return x
|
|
|
|
|
|
class TransformerDecoderLayerRelPos(nn.Module):
    """
    Modified from torch.nn.TransformerDecoderLayer.
    Add it to use normalize_before (hardcoded to True), i.e. use layer_norm before the first block;
    to use relative positional encoding; and for some changes/simplifications in interface
    because both sequences are the same length and have the same mask.

    Args:
      d_model:
        the number of expected features in the input (required).
      nhead:
        the number of heads in the multiheadattention models (required).
      dim_feedforward:
        the dimension of the feedforward network model (default=2048).
      dropout:
        the dropout value (default=0.1).
      activation:
        the activation function of intermediate layer, relu or
        gelu (default=relu).

    Examples::
        >>> decoder_layer = TransformerDecoderLayerRelPos(d_model=512, nhead=8)
        >>> memory = torch.rand(20, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> pos_emb = torch.rand(1, 20*2-1, 512)
        >>> out = decoder_layer(tgt, pos_emb, memory)
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: str = "relu",
    ) -> None:
        super(TransformerDecoderLayerRelPos, self).__init__()
        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
        self.src_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # One pre-norm per sub-block: self-attention, cross-attention, feedforward.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # Fall back to relu when the restored state has no "activation" entry.
        if "activation" not in state:
            state["activation"] = nn.functional.relu
        super(TransformerDecoderLayerRelPos, self).__setstate__(state)

    def forward(
        self,
        x: torch.Tensor,
        pos_emb: torch.Tensor,
        memory: torch.Tensor,
        x_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Pass the inputs (and mask) through the decoder layer.

        Args:
          x
            The input embedding, to be added to by the forward function, of shape (T, N, C).
            Attention within x will be left-to-right only (causal), thanks to x_mask.
          pos_emb:
            A torch.Tensor with dtype=torch.float and shape (1, 2*T-1, C) with c==num_channels,
            containing the relative positional encoding.
          memory:
            the sequence from the last layer of the encoder (required).  Shape = (T, N, C)
          x_mask:
            the mask for the x, to enforce causal (left to right) attention (optional).
            Shape == (T, T); may be bool or float.  The first T pertains to the output,
            the second T to the input.
          key_padding_mask:
            the key-padding mask to use for both the x and memory sequences.  Shape == (N, T);
            may be bool (True==masked) or float (to be added to attention scores).

        Returns:
          Returns 'x plus something', a torch.Tensor with dtype the same as x (e.g. float),
          and shape (T, N, C).
        """
        # Causal self-attention over x (pre-norm, residual).  pos_emb is
        # passed by keyword, as the encoder layer does for the same
        # attention class.
        residual = x
        x = self.norm1(x)
        self_attn = self.self_attn(x, x, x,
                                   pos_emb=pos_emb,
                                   key_padding_mask=key_padding_mask,
                                   need_weights=False,
                                   attn_mask=x_mask,
                                   )[0]
        x = residual + self.dropout1(self_attn)

        # Cross-attention from x to the encoder output; no causal mask here.
        residual = x
        x = self.norm2(x)
        src_attn = self.src_attn(x, memory, memory,
                                 pos_emb=pos_emb,
                                 key_padding_mask=key_padding_mask,
                                 need_weights=False,
                                 )[0]
        x = residual + self.dropout2(src_attn)

        # Position-wise feedforward block.
        residual = x
        x = self.norm3(x)
        ff = self.linear2(self.dropout(self.activation(self.linear1(x))))
        x = residual + self.dropout3(ff)
        return x
|
|
|
|
|
|
def _get_activation_fn(activation: str):
|
|
if activation == "relu":
|
|
return nn.functional.relu
|
|
elif activation == "gelu":
|
|
return nn.functional.gelu
|
|
|
|
raise RuntimeError(
|
|
"activation should be relu/gelu, not {}".format(activation)
|
|
)
|
|
|
|
|
|
class PositionalEncoding(nn.Module):
    """Sinusoidal (absolute) positional encoding, as proposed in:

    - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf

        PE(pos, 2i)   = sin(pos / (10000^(2i/d_model)))
        PE(pos, 2i+1) = cos(pos / (10000^(2i/d_model)))

    Note::

        1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model)))
                                 = exp(-1 * 2i / d_model * log(10000))
                                 = exp(2i * -(log(10000) / d_model))
    """

    def __init__(self, d_model: int, dropout: float = 0.1) -> None:
        """
        Args:
          d_model:
            Embedding dimension.
          dropout:
            Dropout probability to be applied to the output of this module.
        """
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = nn.Dropout(p=dropout)
        # Cached encoding table; built lazily by extend_pe().
        self.pe = None

    def extend_pe(self, x: torch.Tensor) -> None:
        """Make sure the cached table `self.pe` covers the length of `x`.

        `self.pe` has shape [1, T1, d_model] and the input x has shape
        [N, T, d_model].  If the table is missing or T > T1, it is rebuilt
        with shape [1, T, d_model]; otherwise it is only moved to x's
        dtype/device when those differ.

        Args:
          x:
            It is a tensor of shape [N, T, C].
        Returns:
          Return None.
        """
        num_frames = x.size(1)
        cached = self.pe
        if cached is not None and cached.size(1) >= num_frames:
            if cached.dtype != x.dtype or cached.device != x.device:
                self.pe = cached.to(dtype=x.dtype, device=x.device)
            return

        position = torch.arange(0, num_frames, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        angles = position * div_term

        table = torch.zeros(num_frames, self.d_model, dtype=torch.float32)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        # Add the leading batch axis: [1, T, d_model].
        self.pe = table.unsqueeze(0).to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Scale the input by sqrt(d_model), add the positional encoding and
        apply dropout.

        Args:
          x:
            Its shape is [N, T, C]

        Returns:
          Return a tensor of shape [N, T, C]
        """
        self.extend_pe(x)
        encoding = self.pe[:, : x.size(1), :]
        return self.dropout(x * self.xscale + encoding)
|
|
|
|
|
|
class Noam(object):
    """
    Adam wrapped with the Noam learning-rate schedule.

    Proposed in
    "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf

    Modified from
    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py  # noqa

    Args:
      params:
        iterable of parameters to optimize or dicts defining parameter groups
      model_size:
        attention dimension of the transformer model
      factor:
        learning rate factor
      warm_step:
        warmup steps
      weight_decay:
        weight decay passed to the underlying Adam optimizer
    """

    def __init__(
        self,
        params,
        model_size: int = 256,
        factor: float = 10.0,
        warm_step: int = 25000,
        weight_decay=0,
    ) -> None:
        """Construct an Noam object."""
        # lr starts at 0; step() overwrites it before every update.
        self.optimizer = torch.optim.Adam(
            params,
            lr=0,
            betas=(0.9, 0.98),
            eps=1e-9,
            weight_decay=weight_decay,
        )
        self._step = 0
        self._rate = 0
        self.warmup = warm_step
        self.factor = factor
        self.model_size = model_size

    @property
    def param_groups(self):
        """Return param_groups."""
        return self.optimizer.param_groups

    def step(self):
        """Update parameters and rate."""
        self._step += 1
        new_rate = self.rate()
        for group in self.optimizer.param_groups:
            group["lr"] = new_rate
        self._rate = new_rate
        self.optimizer.step()

    def rate(self, step=None):
        """Return the learning rate for `step` (default: the current step).

        The value is
        factor * model_size^(-0.5) * min(step^(-0.5), step * warmup^(-1.5)).
        """
        current = self._step if step is None else step
        scale = self.factor * self.model_size ** (-0.5)
        return scale * min(current ** (-0.5), current * self.warmup ** (-1.5))

    def zero_grad(self):
        """Reset gradient."""
        self.optimizer.zero_grad()

    def state_dict(self):
        """Return state_dict."""
        state = {
            "_step": self._step,
            "warmup": self.warmup,
            "factor": self.factor,
            "model_size": self.model_size,
            "_rate": self._rate,
        }
        state["optimizer"] = self.optimizer.state_dict()
        return state

    def load_state_dict(self, state_dict):
        """Load state_dict."""
        for name, value in state_dict.items():
            if name != "optimizer":
                # Plain schedule fields are restored as attributes.
                setattr(self, name, value)
                continue
            self.optimizer.load_state_dict(state_dict["optimizer"])
|
|
|
|
|
|
class LabelSmoothingLoss(nn.Module):
    """
    Label-smoothing loss.  Minimizes the KL-divergence between
    q_{smoothed ground truth prob.}(w) and p_{prob. computed by model}(w).
    Modified from
    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py  # noqa

    Args:
        size: the number of class
        padding_idx: padding_idx: ignored class id
        smoothing: smoothing rate; must satisfy 0.0 < smoothing <= 1.0
        normalize_length: normalize loss by the number of non-padding
            targets if True
        criterion: loss function to be smoothed
    """

    def __init__(
        self,
        size: int,
        padding_idx: int = -1,
        smoothing: float = 0.1,
        normalize_length: bool = False,
        criterion: nn.Module = nn.KLDivLoss(reduction="none"),
    ) -> None:
        """Construct an LabelSmoothingLoss object."""
        super(LabelSmoothingLoss, self).__init__()
        self.criterion = criterion
        self.padding_idx = padding_idx
        assert 0.0 < smoothing <= 1.0
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None
        self.normalize_length = normalize_length

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute loss between x and target.

        Args:
          x:
            prediction of dimension
            (batch_size, input_length, number_of_classes).
          target:
            target masked with self.padding_idx of
            dimension (batch_size, input_length).

        Returns:
          A scalar tensor containing the loss, summed over all non-padding
          positions; it is divided by the number of those positions only
          when self.normalize_length is True.
        """
        assert x.size(2) == self.size
        logits = x.view(-1, self.size)
        flat_target = target.view(-1)

        with torch.no_grad():
            # Every class starts with the smoothing mass; the target class
            # is then overwritten with the confidence.
            smoothed = torch.full_like(logits, self.smoothing / (self.size - 1))
            is_padding = flat_target == self.padding_idx  # (B,)
            num_valid = len(flat_target) - is_padding.sum().item()
            # Padding targets are mapped to class 0 so that scatter_ gets a
            # valid index; their rows are zeroed out of the loss below.
            safe_target = flat_target.masked_fill(is_padding, 0)
            smoothed.scatter_(1, safe_target.unsqueeze(1), self.confidence)

        divergence = self.criterion(torch.log_softmax(logits, dim=1), smoothed)
        divergence = divergence.masked_fill(is_padding.unsqueeze(1), 0)
        denom = num_valid if self.normalize_length else 1
        return divergence.sum() / denom
|
|
|
|
|
|
|
|
def generate_square_subsequent_mask(sz: int, device: torch.device = torch.device('cpu')) -> torch.Tensor:
    """Build the additive causal mask used for masked self-attention.

    Entries strictly above the diagonal are float('-inf'); entries on and
    below the diagonal are 0.0.

    For instance, if sz is 3, it returns::

        tensor([[0., -inf, -inf],
                [0., 0., -inf],
                [0., 0., 0]])

    Args:
      sz: mask size
      device: device on which the mask is created

    Returns:
      A square float32 mask of dimension (sz, sz)
    """
    blocked = torch.full(
        (sz, sz), float("-inf"), dtype=torch.float32, device=device
    )
    # Keep -inf only strictly above the diagonal; everything else becomes 0.
    return torch.triu(blocked, diagonal=1)
|
|
|
|
|
|
def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
    """Return a copy of `token_ids` with `sos_id` put in front of every utterance.

    Args:
      token_ids:
        A list-of-list of token IDs. Each sublist contains
        token IDs (e.g., word piece IDs) of an utterance.
      sos_id:
        The ID of the SOS token.

    Return:
      Return a new list-of-list, where each sublist starts
      with SOS ID.  The input lists are not modified.
    """
    return [[sos_id] + utt for utt in token_ids]
|
|
|
|
|
|
def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
    """Return a copy of `token_ids` with `eos_id` added after every utterance.

    Args:
      token_ids:
        A list-of-list of token IDs. Each sublist contains
        token IDs (e.g., word piece IDs) of an utterance.
      eos_id:
        The ID of the EOS token.

    Return:
      Return a new list-of-list, where each sublist ends
      with EOS ID.  The input lists are not modified.
    """
    return [utt + [eos_id] for utt in token_ids]
|
|
|
|
|
|
|
|
class MaskedConvolutionModule(nn.Module):
    """
    This is used in the MaskedLmConformerLayer.  It is the same as the ConvolutionModule
    of the Conformer code, but with key_padding_mask supported to make the output independent
    of the batching.

    Modified, ultimately, from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py

    Args:
        channels (int): The number of channels of conv layers.
        kernel_size (int): Kernel size of conv layers.
        bias (bool): Whether to use bias in conv layers (default=True).
    """

    def __init__(
        self, channels: int, kernel_size: int, bias: bool = True
    ) -> None:
        """Construct a MaskedConvolutionModule object."""
        super(MaskedConvolutionModule, self).__init__()
        # kernel_size should be an odd number for 'SAME' padding
        assert (kernel_size - 1) % 2 == 0

        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            groups=channels,
            bias=bias,
        )
        # Normalizes over the channel axis; see forward().
        self.norm = nn.LayerNorm(channels)
        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = Swish()

    def forward(self, x: Tensor, key_padding_mask: Optional[Tensor]) -> Tensor:
        """Compute convolution module.

        Args:
            x: Input tensor (T, N, C) == (#time, batch, channels).
            key_padding_mask: if supplied, a Tensor with dtype=torch.bool and
                  shape (N, T), with True for positions that correspond to
                  padding (and should be zeroed in convolutions).

        Returns:
            Tensor: Output tensor (T, N, C)

        """
        # exchange the temporal dimension and the feature dimension
        x = x.permute(1, 2, 0)  # (#batch, channels, time).

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
        x = nn.functional.glu(x, dim=1)  # (batch, channels, time)

        # Logical-not key_padding_mask, unsqueeze to shape (N, 1, T) and convert
        # to float.  Then we can just multiply by it when we need to apply
        # masking, i.e. prior to the convolution over time.
        if key_padding_mask is not None:
            x = x * torch.logical_not(key_padding_mask).unsqueeze(1).to(dtype=x.dtype)

        # 1D Depthwise Conv
        x = self.depthwise_conv(x)  # (batch, channels, time)

        # LayerNorm normalizes over the LAST axis, so the channel axis has to
        # be moved last for the normalization and then moved back.  Applying
        # it directly to (batch, channels, time) would normalize over time
        # and fail whenever time != channels.
        x = self.norm(x.transpose(1, 2)).transpose(1, 2)
        x = self.activation(x)

        x = self.pointwise_conv2(x)  # (batch, channel, time)

        return x.permute(2, 0, 1)  # (time, batch, channel)
|
|
|
|
|
|
class Swish(torch.nn.Module):
    """Swish activation module: x * sigmoid(x)."""

    def forward(self, x: Tensor) -> Tensor:
        """Apply the Swish activation element-wise and return the result."""
        gate = torch.sigmoid(x)
        return x * gate
|
|
|
|
|
|
|
|
class MaskedLmConformerEncoderLayer(nn.Module):
    """
    MaskedLmConformerEncoderLayer is made up of self-attn, feedforward and convolution
    networks.  It's a simplified version of the conformer code we were previously
    using, with pre-normalization hard-coded, relative positional encoding,
    LayerNorm instead of BatchNorm in the convolution layers, and the key_padding_mask
    applied also in the convolution layers so the computation is independent of
    how the sequences are batched.

    See: "Conformer: Convolution-augmented Transformer for Speech Recognition", for
    the basic conformer.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        cnn_module_kernel (int): Kernel size of convolution module.

    Examples::
        >>> encoder_layer = MaskedLmConformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> pos_emb = torch.rand(1, 19, 512)
        >>> out = encoder_layer(src, pos_emb)
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        cnn_module_kernel: int = 31,
    ) -> None:
        super(MaskedLmConformerEncoderLayer, self).__init__()
        self.self_attn = RelPositionMultiheadAttention(
            d_model, nhead, dropout=0.0
        )

        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            Swish(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        )

        self.feed_forward_macaron = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            Swish(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        )

        self.conv_module = MaskedConvolutionModule(d_model, cnn_module_kernel)

        self.norm_ff_macaron = nn.LayerNorm(
            d_model
        )  # for the macaron style FNN module
        self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
        self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module

        # Each of the two feedforward modules contributes with weight 0.5.
        self.ff_scale = 0.5

        self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
        self.norm_final = nn.LayerNorm(
            d_model
        )  # for the final output of the block

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: Tensor,
        pos_emb: Tensor,
        attn_mask: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Pass the input through the encoder layer.

        Args:
            x: the sequence to the encoder layer (required).
            pos_emb: Positional embedding tensor (required).
            attn_mask: the mask for the x sequence's attention to itself (optional);
                 of shape (T, T)
            key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            x: (T, N, C) i.e. (seq_len, batch_size, num_channels)
            pos_emb: (1, 2*T-1, C)
            attn_mask: (T, T) or (N*num_heads, T, T), of dtype torch.bool or torch.float, where
                    the 1st T is interpreted as the target sequence (output) and the 2nd as the source
                    sequence (input).
            key_padding_mask: (N, T), of dtype torch.bool

            T is the sequence length, N is the batch size, C is the number of channels.
        Return:
            Returns a tensor of shape (T, N, C): x with the outputs of the
            sub-modules added to it, then passed through the final LayerNorm.
        """

        # macaron style feed forward module
        residual = x
        x = self.norm_ff_macaron(x)
        x = residual + self.ff_scale * self.dropout(
            self.feed_forward_macaron(x)
        )

        # multi-headed self-attention module
        residual = x
        x = self.norm_mha(x)
        self_attn = self.self_attn(x, x, x,
                                   pos_emb=pos_emb,
                                   attn_mask=attn_mask,
                                   key_padding_mask=key_padding_mask,
                                   need_weights=False
                                   )[0]
        x = residual + self.dropout(self_attn)

        # convolution module
        residual = x
        x = self.norm_conv(x)

        x = residual + self.dropout(self.conv_module(x, key_padding_mask=key_padding_mask))

        # feed forward module
        residual = x
        x = self.norm_ff(x)
        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))

        x = self.norm_final(x)

        return x
|
|
|
|
|
|
def _get_clones(module, N):
|
|
return ModuleList([copy.deepcopy(module) for i in range(N)])
|
|
|
|
class MaskedLmConformerEncoder(nn.Module):
    r"""A stack of N encoder layers, modified from torch.nn.TransformerEncoder.
    The only differences are some name changes for parameters.

    Args:
        encoder_layer: an instance of the MaskedLmConformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> encoder_layer = MaskedLmConformerEncoderLayer(d_model=512, nhead=8)
        >>> conformer_encoder = MaskedLmConformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> src, pos_emb = self.encoder_pos(src)
        >>> out = conformer_encoder(src, pos_emb)
    """
    __constants__ = ['norm']

    def __init__(self, encoder_layer: nn.Module, num_layers: int,
                 norm: Optional[nn.Module] = None):
        super(MaskedLmConformerEncoder, self).__init__()
        # The layer is cloned `num_layers` times.
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        x: Tensor,
        pos_emb: Tensor,
        attn_mask: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Run the input through every encoder layer, then the optional norm.

        Args
          x: input of shape (T, N, C), i.e. (seq_len, batch, channels)
          pos_emb: positional embedding tensor of shape (1, 2*T-1, C),
          attn_mask (optional, likely not used): mask for self-attention of
                 x to itself, of shape (T, T)
          key_padding_mask (optional): mask of shape (N, T), dtype must be bool.
        Returns:
          Returns a tensor with the same shape as x, i.e. (T, N, C).
        """
        hidden = x
        for layer in self.layers:
            hidden = layer(
                hidden,
                pos_emb,
                attn_mask=attn_mask,
                key_padding_mask=key_padding_mask,
            )

        if self.norm is None:
            return hidden
        return self.norm(hidden)
|
|
|
|
|
|
class RelPositionalEncoding(torch.nn.Module):
|
|
"""Relative positional encoding module.
|
|
|
|
See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
|
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
|
|
|
|
Args:
|
|
d_model: Embedding dimension.
|
|
dropout_rate: Dropout rate.
|
|
max_len: Maximum input length.
|
|
|
|
"""
|
|
|
|
def __init__(
|
|
self, d_model: int, dropout_rate: float, max_len: int = 5000
|
|
) -> None:
|
|
"""Construct an PositionalEncoding object."""
|
|
super(RelPositionalEncoding, self).__init__()
|
|
self.d_model = d_model
|
|
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
|
self.pe = None
|
|
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
|
|
|
def extend_pe(self, x: Tensor) -> None:
|
|
"""Reset the positional encodings."""
|
|
if self.pe is not None:
|
|
# self.pe contains both positive and negative parts
|
|
# the length of self.pe is 2 * input_len - 1
|
|
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
|
# Note: TorchScript doesn't implement operator== for torch.Device
|
|
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
|
x.device
|
|
):
|
|
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
|
return
|
|
# Suppose `i` means to the position of query vector and `j` means the
|
|
# position of key vector. We use position relative positions when keys
|
|
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
|
pe_positive = torch.zeros(x.size(1), self.d_model)
|
|
pe_negative = torch.zeros(x.size(1), self.d_model)
|
|
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
|
div_term = torch.exp(
|
|
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
|
* -(math.log(10000.0) / self.d_model)
|
|
)
|
|
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
|
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
|
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
|
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
|
|
|
# Reserve the order of positive indices and concat both positive and
|
|
# negative indices. This is used to support the shifting trick
|
|
# as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
|
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
|
pe_negative = pe_negative[1:].unsqueeze(0)
|
|
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
|
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
|
|
|
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
    """Apply dropout to the input and return the matching relative encoding.

    Args:
      x: Input tensor of shape (batch, time, C).

    Returns:
      A pair ``(x, pos_enc)`` where

        - x is the input after dropout, shape (batch, time, C);
        - pos_enc is the window of the cached table centred on relative
          offset zero, after dropout, shape (1, 2*time-1, C).
    """
    self.extend_pe(x)

    num_frames = x.size(1)
    # The middle row of the cached table corresponds to relative offset 0.
    centre = self.pe.size(1) // 2
    window_start = centre - num_frames + 1
    window_end = centre + num_frames
    pos_emb = self.pe[:, window_start:window_end]

    dropped_x = self.dropout(x)
    dropped_pos = self.dropout(pos_emb)
    return dropped_x, dropped_pos
|
|
|
|
|
|
class RelPositionMultiheadAttention(nn.Module):
    r"""Multi-Head Attention layer with relative position encoding

    See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads; must divide ``embed_dim``.
        dropout: dropout probability applied to attn_output_weights. Default: 0.0.

    Examples::

        >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = rel_pos_multihead_attn(query, key, value, pos_emb)
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
    ) -> None:
        super(RelPositionMultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        # Stored as a probability; it is handed to nn.functional.dropout in
        # multi_head_attention_forward().
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        # A single linear layer produces q, k and v stacked along the last dim.
        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

        # linear transformation for positional encoding.
        self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
        # these two learnable biases are used in matrix c and matrix d
        # as described in "Transformer-XL: Attentive Language Models Beyond a
        # Fixed-Length Context" Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
        self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))

        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Initialise the input projection, the biases and the position biases.

        ``out_proj.weight`` and ``linear_pos.weight`` keep the default
        ``nn.Linear`` initialisation.
        """
        nn.init.xavier_uniform_(self.in_proj.weight)
        nn.init.constant_(self.in_proj.bias, 0.0)
        nn.init.constant_(self.out_proj.bias, 0.0)

        nn.init.xavier_uniform_(self.pos_bias_u)
        nn.init.xavier_uniform_(self.pos_bias_v)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        pos_emb: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""Run relative-position attention with this module's parameters.

        Args (see below for shapes):
            query, key, value: map a query and a set of key-value pairs to an output.
            pos_emb: Positional embedding tensor
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. When given a binary mask and a value is True,
                the corresponding value on the attention layer will be ignored. When given
                a byte mask and a value is non-zero, the corresponding value on the attention
                layer will be ignored
            need_weights: if true, return (output, attn_output_weights); else, (output, None).
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask
                will be broadcasted for all the batches while a 3D mask allows to specify a
                different mask for the entries of each batch.

        Shape:
            - Inputs:
            - query: :math:`(T, N, C)` where T is the output sequence length, N is the batch size,
              C is the embedding dimension (number of channels).
            - key: :math:`(S, N, C)`, where S is the input sequence length.
            - value: :math:`(S, N, C)`
            - pos_emb: :math:`(N, 2*T-1, C)` or :math:`(1, 2*T-1, C)`. Note: this assumes T == S,
              which it will be, but still we use different letters because S relates to the
              input position, T to the output position.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the input sequence
              length. If a ByteTensor is provided, the non-zero positions will be ignored while
              the zero positions will be unchanged. If a BoolTensor is provided, the positions
              with the value of ``True`` will be ignored while the positions with the value of
              ``False`` will be unchanged.
            - attn_mask: 2D mask :math:`(T, S)` where T is the output sequence length, S is the
              input sequence length. 3D mask :math:`(N*num_heads, T, S)` where N is the batch
              size. attn_mask ensures that position i is allowed to attend the unmasked
              positions. If a ByteTensor is provided, the non-zero positions are not allowed to
              attend while the zero positions will be unchanged. If a BoolTensor is provided,
              positions with ``True`` are not allowed to attend while ``False`` values will be
              unchanged. If a FloatTensor is provided, it will be added to the attention weight.

        Return:
            (output, attn_output_weights) if need_weights==True, else (output, None), where:

            - output: :math:`(T, N, C)` where T is the output sequence length, N is the batch
              size, C is the embedding/channel dimension.
            - attn_output_weights: :math:`(N, T, S)` where N is the batch size,
              T is the output sequence length, S is the input sequence length (actually
              S and T are the same number).
        """
        return self.multi_head_attention_forward(
            query,
            key,
            value,
            pos_emb,
            self.embed_dim,
            self.num_heads,
            self.in_proj.weight,
            self.in_proj.bias,
            self.dropout,
            self.out_proj.weight,
            self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
        )

    def rel_shift(self, x: Tensor) -> Tensor:
        """Compute relative positional encoding.

        Re-indexes the last axis so that entry ``[..., i, j]`` of the result is
        entry ``[..., i, (time1 - 1) - i + j]`` of the input. The result is a
        strided view of ``x``; no data is copied.

        Args:
            x: Input tensor (batch, head, time1, 2*time1-1).
                time1 means the length of query vector.

        Returns:
            Tensor: tensor of shape (batch, head, time1, time2)
          (note: time2 has the same value as time1, but it is for
          the key, while time1 is for the query).
        """
        (batch_size, num_heads, time1, n) = x.shape
        assert n == 2 * time1 - 1
        # Note: TorchScript requires explicit arg for stride()
        batch_stride = x.stride(0)
        head_stride = x.stride(1)
        time1_stride = x.stride(2)
        n_stride = x.stride(3)
        # as_strided() takes an offset that is absolute in the underlying
        # storage, so the offset of `x` itself has to be included; otherwise a
        # sliced/viewed input would be read from the wrong place.
        return x.as_strided(
            (batch_size, num_heads, time1, time1),
            (batch_stride, head_stride, time1_stride - n_stride, n_stride),
            storage_offset=x.storage_offset() + n_stride * (time1 - 1),
        )

    @staticmethod
    def _project_qkv(
        query: Tensor,
        key: Tensor,
        value: Tensor,
        in_proj_weight: Tensor,
        in_proj_bias: Optional[Tensor],
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Apply the stacked input projection and return (q, k, v).

        Rows ``[0, E)`` of ``in_proj_weight`` project the query, rows
        ``[E, 2E)`` the key and rows ``[2E, 3E)`` the value, where E is the
        size of the last dimension of ``query``. Inputs that are equal share a
        single matmul.
        """
        embed_dim = query.size(-1)
        linear = nn.functional.linear

        # Identity is checked first so the element-wise comparison is skipped
        # in the common case where the caller passes the same tensor object.
        kv_shared = key is value or torch.equal(key, value)
        if kv_shared and (query is key or torch.equal(query, key)):
            # self-attention: one projection, split three ways
            q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(
                3, dim=-1
            )
            return q, k, v

        q_bias = None if in_proj_bias is None else in_proj_bias[:embed_dim]
        q = linear(query, in_proj_weight[:embed_dim, :], q_bias)

        if kv_shared:
            # encoder-decoder attention: key and value share one projection
            kv_bias = None if in_proj_bias is None else in_proj_bias[embed_dim:]
            k, v = linear(key, in_proj_weight[embed_dim:, :], kv_bias).chunk(
                2, dim=-1
            )
            return q, k, v

        # Fully general case: three separate projections.
        if in_proj_bias is None:
            k_bias = None
            v_bias = None
        else:
            k_bias = in_proj_bias[embed_dim : embed_dim * 2]
            v_bias = in_proj_bias[embed_dim * 2 :]
        k = linear(key, in_proj_weight[embed_dim : embed_dim * 2, :], k_bias)
        v = linear(value, in_proj_weight[embed_dim * 2 :, :], v_bias)
        return q, k, v

    def multi_head_attention_forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        pos_emb: Tensor,
        embed_dim_to_check: int,
        num_heads: int,
        in_proj_weight: Tensor,
        in_proj_bias: Tensor,
        dropout_p: float,
        out_proj_weight: Tensor,
        out_proj_bias: Tensor,
        training: bool = True,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""Functional form of the attention computation.

        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
            pos_emb: Positional embedding tensor
            embed_dim_to_check: total dimension of the model.
            num_heads: parallel attention heads.
            in_proj_weight, in_proj_bias: input projection weight and bias.
            dropout_p: probability of an element to be zeroed.
            out_proj_weight, out_proj_bias: the output projection weight and bias.
            training: apply dropout if is ``True``.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. This is a binary mask. When the value is True,
                the corresponding value on the attention layer will be filled with -inf.
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask
                will be broadcasted for all the batches while a 3D mask allows to specify a
                different mask for the entries of each batch.

        Shape:
            Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch
              size, E is the embedding dimension.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch
              size, E is the embedding dimension.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch
              size, E is the embedding dimension.
            - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target
              sequence length, N is the batch size, E is the embedding dimension.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source
              sequence length. If a ByteTensor is provided, the non-zero positions will be
              ignored while the zero positions will be unchanged. If a BoolTensor is provided,
              the positions with the value of ``True`` will be ignored while the positions
              with the value of ``False`` will be unchanged.
            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the
              source sequence length. 3D mask :math:`(N*num_heads, L, S)` where N is the batch
              size. attn_mask ensures that position i is allowed to attend the unmasked
              positions. If a ByteTensor is provided, the non-zero positions are not allowed to
              attend while the zero positions will be unchanged. If a BoolTensor is provided,
              positions with ``True`` are not allowed to attend while ``False`` values will be
              unchanged. If a FloatTensor is provided, it will be added to the attention weight.

            Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the
              batch size, E is the embedding dimension.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
              L is the target sequence length, S is the source sequence length.
              Weights are averaged over heads. ``None`` when ``need_weights`` is False.

        Raises:
            RuntimeError: if ``attn_mask`` is not 2D/3D or has the wrong size.
        """
        # Imported here because the module's import block does not provide
        # `warnings`; without it the byte-mask paths below raise NameError.
        import warnings

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == embed_dim_to_check
        assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

        head_dim = embed_dim // num_heads
        assert (
            head_dim * num_heads == embed_dim
        ), "embed_dim must be divisible by num_heads"
        scaling = float(head_dim) ** -0.5

        q, k, v = self._project_qkv(
            query, key, value, in_proj_weight, in_proj_bias
        )

        if attn_mask is not None:
            assert (
                attn_mask.dtype == torch.float32
                or attn_mask.dtype == torch.float64
                or attn_mask.dtype == torch.float16
                or attn_mask.dtype == torch.uint8
                or attn_mask.dtype == torch.bool
            ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
                attn_mask.dtype
            )
            if attn_mask.dtype == torch.uint8:
                warnings.warn(
                    "Byte tensor for attn_mask is deprecated. Use bool tensor instead."
                )
                attn_mask = attn_mask.to(torch.bool)

            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(0)
                if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                    raise RuntimeError(
                        "The size of the 2D attn_mask is not correct."
                    )
            elif attn_mask.dim() == 3:
                if list(attn_mask.size()) != [
                    bsz * num_heads,
                    query.size(0),
                    key.size(0),
                ]:
                    raise RuntimeError(
                        "The size of the 3D attn_mask is not correct."
                    )
            else:
                raise RuntimeError(
                    "attn_mask's dimension {} is not supported".format(
                        attn_mask.dim()
                    )
                )
            # attn_mask's dim is 3 now.

        # convert ByteTensor key_padding_mask to bool
        if (
            key_padding_mask is not None
            and key_padding_mask.dtype == torch.uint8
        ):
            warnings.warn(
                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
            )
            key_padding_mask = key_padding_mask.to(torch.bool)

        q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim)
        k = k.contiguous().view(-1, bsz, num_heads, head_dim)
        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

        src_len = k.size(0)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz, "{} == {}".format(
                key_padding_mask.size(0), bsz
            )
            assert key_padding_mask.size(1) == src_len, "{} == {}".format(
                key_padding_mask.size(1), src_len
            )

        q = q.transpose(0, 1)  # (batch, time1, head, d_k)

        pos_emb_bsz = pos_emb.size(0)
        assert pos_emb_bsz in (1, bsz)  # actually it is 1
        p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)

        q_with_bias_u = (q + self.pos_bias_u).transpose(
            1, 2
        )  # (batch, head, time1, d_k)

        q_with_bias_v = (q + self.pos_bias_v).transpose(
            1, 2
        )  # (batch, head, time1, d_k)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in "Transformer-XL: Attentive Language Models Beyond a
        # Fixed-Length Context" Section 3.3
        k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
        matrix_ac = torch.matmul(
            q_with_bias_u, k
        )  # (batch, head, time1, time2)

        # compute matrix b and matrix d
        matrix_bd = torch.matmul(
            q_with_bias_v, p.transpose(-2, -1)
        )  # (batch, head, time1, 2*time1-1)
        matrix_bd = self.rel_shift(matrix_bd)

        # The 1/sqrt(d_k) scaling is applied to the summed scores (content and
        # position terms together) rather than to q beforehand.
        attn_output_weights = (
            matrix_ac + matrix_bd
        ) * scaling  # (batch, head, time1, time2)

        attn_output_weights = attn_output_weights.view(
            bsz * num_heads, tgt_len, -1
        )

        assert list(attn_output_weights.size()) == [
            bsz * num_heads,
            tgt_len,
            src_len,
        ]

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_output_weights.masked_fill_(attn_mask, float("-inf"))
            else:
                attn_output_weights += attn_mask

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.view(
                bsz, num_heads, tgt_len, src_len
            )
            # (N, S) -> (N, 1, 1, S) so the mask broadcasts over heads and
            # query positions.
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float("-inf"),
            )
            attn_output_weights = attn_output_weights.view(
                bsz * num_heads, tgt_len, src_len
            )

        attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
        attn_output_weights = nn.functional.dropout(
            attn_output_weights, p=dropout_p, training=training
        )

        attn_output = torch.bmm(attn_output_weights, v)
        assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
        attn_output = (
            attn_output.transpose(0, 1)
            .contiguous()
            .view(tgt_len, bsz, embed_dim)
        )
        attn_output = nn.functional.linear(
            attn_output, out_proj_weight, out_proj_bias
        )

        if not need_weights:
            return attn_output, None

        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(
            bsz, num_heads, tgt_len, src_len
        )
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
|