diff --git a/egs/librispeech/ASR/conformer_lm/conformer.py b/egs/librispeech/ASR/conformer_lm/conformer.py
index c727da341..9f0db2e81 100644
--- a/egs/librispeech/ASR/conformer_lm/conformer.py
+++ b/egs/librispeech/ASR/conformer_lm/conformer.py
@@ -1,36 +1,21 @@
-#!/usr/bin/env python3
-
-# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
+# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
 # Apache 2.0
 
+import copy
 import math
-import warnings
-from typing import Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
-from torch import Tensor, nn
-from transformer import Supervisions, Transformer, encoder_padding_mask
+import torch.nn as nn
+from torch import Tensor
+from torch.nn.utils.rnn import pad_sequence
+
+# Note: TorchScript requires Dict/List/etc. to be fully typed.
+Supervisions = Dict[str, torch.Tensor]
 
 
-class Conformer(Transformer):
-    """
-    Args:
-        num_features (int): Number of input features
-        num_classes (int): Number of output classes
-        subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
-        d_model (int): attention dimension
-        nhead (int): number of head
-        dim_feedforward (int): feedforward dimention
-        num_encoder_layers (int): number of encoder layers
-        num_decoder_layers (int): number of decoder layers
-        dropout (float): dropout rate
-        cnn_module_kernel (int): Kernel size of convolution module
-        normalize_before (bool): whether to use layer_norm before the first block.
-    """
-
+class MaskedLmConformer(nn.Module):
     def __init__(
         self,
-        num_features: int,
         num_classes: int,
         d_model: int = 256,
         nhead: int = 4,
@@ -39,81 +24,766 @@ class Conformer(Transformer):
         num_decoder_layers: int = 6,
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
-        normalize_before: bool = True,
-        is_espnet_structure: bool = False,
     ) -> None:
-        super(Conformer, self).__init__(
-            num_features=num_features,
-            num_classes=num_classes,
-            subsampling_factor=subsampling_factor,
-            d_model=d_model,
-            nhead=nhead,
-            dim_feedforward=dim_feedforward,
-            num_encoder_layers=num_encoder_layers,
-            num_decoder_layers=num_decoder_layers,
-            dropout=dropout,
-            normalize_before=normalize_before,
+        """
+        Args:
+          num_classes:
+            The input and output dimension of the model (the inputs and outputs
+            are both discrete symbol sequences).
+          d_model:
+            Attention dimension.
+          nhead:
+            Number of heads in multi-head attention.
+            Must satisfy d_model % nhead == 0.
+          dim_feedforward:
+            The hidden dimension of the feedforward layers in encoder/decoder.
+          num_encoder_layers:
+            Number of encoder layers.
+          num_decoder_layers:
+            Number of decoder layers.
+          dropout:
+            Dropout in encoder/decoder.
+        """
+        super(MaskedLmConformer, self).__init__()
+
+        self.num_classes = num_classes
+
+        # self.embed is the embedding used for both the encoder and decoder.
+        self.embed_scale = d_model ** 0.5
+        self.embed = nn.Embedding(
+            num_embeddings=num_classes, embedding_dim=d_model,
+            _weight=torch.randn(num_classes, d_model) * (1 / self.embed_scale)
         )
 
         self.encoder_pos = RelPositionalEncoding(d_model, dropout)
 
-        encoder_layer = ConformerEncoderLayer(
+        encoder_layer = MaskedLmConformerEncoderLayer(
             d_model,
             nhead,
             dim_feedforward,
             dropout,
             cnn_module_kernel,
-            normalize_before,
-            is_espnet_structure,
         )
-        self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
-        self.normalize_before = normalize_before
-        self.is_espnet_structure = is_espnet_structure
-        if self.normalize_before and self.is_espnet_structure:
-            self.after_norm = nn.LayerNorm(d_model)
-        else:
-            # Note: TorchScript detects that self.after_norm could be used inside forward()
-            #       and throws an error without this change.
-            self.after_norm = identity
+        self.encoder = MaskedLmConformerEncoder(encoder_layer, num_encoder_layers,
+                                                norm=nn.LayerNorm(d_model))
 
-    def run_encoder(
-        self, x: Tensor, supervisions: Optional[Supervisions] = None
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+        if num_decoder_layers > 0:
+            self.decoder_num_class = self.num_classes
+
+            decoder_layer = TransformerDecoderLayerRelPos(
+                d_model=d_model,
+                nhead=nhead,
+                dim_feedforward=dim_feedforward,
+                dropout=dropout,
+            )
+
+            # Projects the embedding of `src`, to be added to `memory`
+            self.src_linear = torch.nn.Linear(d_model, d_model)
+
+            decoder_norm = nn.LayerNorm(d_model)
+            self.decoder = TransformerDecoderRelPos(
+                decoder_layer=decoder_layer,
+                num_layers=num_decoder_layers,
+                norm=decoder_norm,
+            )
+
+            self.decoder_output_layer = torch.nn.Linear(
+                d_model, self.decoder_num_class
+            )
+
+    def forward(
+        self,
+        masked_src_symbols: torch.Tensor,
+        key_padding_mask: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
-          x:
-            The model input. Its shape is [N, T, C].
-          supervisions:
-            Supervision in lhotse format.
-            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
-            CAUTION: It contains length information, i.e., start and number of
-            frames, before subsampling
-            It is read directly from the batch, without any sorting. It is used
-            to compute encoder padding mask, which is used as memory key padding
-            mask for the decoder.
+          masked_src_symbols:
+            The input symbols to be embedded (the query positions will actually
+            be masked), as a Tensor of shape (batch_size, seq_len) and
+            dtype=torch.int64, i.e. shape (N, T).
+          key_padding_mask:
+            Either None, or a Tensor of shape (batch_size, seq_len), i.e. (N, T),
+            and dtype=torch.bool, which has True in positions to be masked in
+            attention layers and convolutions because they represent padding at
+            the ends of sequences.
 
         Returns:
-            Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
-            Tensor: Mask tensor of dimension (batch_size, input_length)
+            Returns (encoded, pos_emb), where:
+              `encoded` is a Tensor containing the encoded data, of shape (T, N, C),
+              where C is the embedding_dim.
+              `pos_emb` is a Tensor containing the relative positional encoding, of
+              shape (1, 2*T-1, C).
         """
-        x = self.encoder_embed(x)
-        x, pos_emb = self.encoder_pos(x)
-        x = x.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
-        mask = encoder_padding_mask(x.size(0), supervisions)
-        if mask is not None:
-            mask = mask.to(x.device)
-        x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, B, F)
-        if self.normalize_before and self.is_espnet_structure:
-            x = self.after_norm(x)
+        x = self.embed(masked_src_symbols) * self.embed_scale  # (N, T, C)
+        x, pos_emb = self.encoder_pos(x)  # pos_emb: (1, 2*T-1, C)
+        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
-        return x, mask
+        x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask)  # (T, N, C)
+
+        return x, pos_emb
+
+    def decoder_nll(
+        self,
+        memory: torch.Tensor,
+        pos_emb: torch.Tensor,
+        src_symbols: torch.Tensor,
+        tgt_symbols: torch.Tensor,
+        key_padding_mask: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Args:
+          memory:
+            The output of the encoder, with shape (T, N, C)
+          pos_emb:
+            Relative positional embedding, of shape (1, 2*T-1, C), as
+            returned from the encoder
+          src_symbols:
+            The un-masked src symbols, a LongTensor of shape (N, T).
+            Can be used to predict the target only in a left-to-right
+            manner (otherwise it would be cheating).
+          tgt_symbols:
+            Target symbols, a LongTensor of shape (N, T).
+            The same as src_symbols, but shifted by one (and also
+            without symbol randomization; see randomize_proportion
+            in the dataloader).
+          key_padding_mask:
+            A BoolTensor of shape (N, T), with True for positions
+            that correspond to padding at the end of source and
+            memory sequences.  The same mask is used for self-attention
+            and cross-attention, since the padding is the same.
+
+        Returns:
+          Returns a tensor of shape (N, T), containing the negative
+          log-probabilities for the target symbols at each position
+          in the target sequence.
+        """
+        (T, N, C) = memory.shape
+
+        tgt_mask = generate_square_subsequent_mask(T, memory.device)
+
+        src = self.embed(src_symbols) * self.embed_scale  # (N, T) -> (N, T, C)
+        src = src.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
-class ConformerEncoderLayer(nn.Module):
+        src = memory + self.src_linear(src)  # (T, N, C)
+
+        # This is a little confusing, in that "tgt" is set to src: "src" is the
+        # symbol sequence without masking but with padding and randomization,
+        # and "tgt" is like "src" but shifted by one.
+        pred = self.decoder(
+            src,
+            pos_emb,
+            memory,
+            attn_mask=tgt_mask,
+            key_padding_mask=key_padding_mask,
+        )  # (T, N, C)
+
+        pred = pred.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        pred = self.decoder_output_layer(pred)  # (N, T, C)
+
+        # nll: negative log-likelihood
+        nll = torch.nn.functional.cross_entropy(
+            pred.view(-1, self.decoder_num_class),
+            tgt_symbols.view(-1),
+            reduction="none",
+        )
+        nll = nll.view(N, T)
+        return nll
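For orientation, here is a minimal usage sketch of the encoder/decoder pair above. The vocabulary size, the random "masked" symbols, and the one-position shift via `torch.roll` are illustrative stand-ins for what the recipe's dataloader actually produces:

```python
import torch

from conformer import MaskedLmConformer

num_classes = 1000  # hypothetical vocabulary size
model = MaskedLmConformer(num_classes=num_classes).eval()

N, T = 4, 25
masked_src = torch.randint(0, num_classes, (N, T))  # queries masked out
src = torch.randint(0, num_classes, (N, T))         # un-masked symbols
tgt = torch.roll(src, shifts=-1, dims=1)            # "shifted by one" stand-in
key_padding_mask = torch.zeros(N, T, dtype=torch.bool)  # no padding here

memory, pos_emb = model(masked_src, key_padding_mask)  # (T, N, C), (1, 2*T-1, C)
nll = model.decoder_nll(memory, pos_emb, src, tgt, key_padding_mask)
assert nll.shape == (N, T)
```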
+
+
+class TransformerDecoderRelPos(nn.Module):
+    r"""TransformerDecoderRelPos is a stack of N decoder layers.
+    This is modified from nn.TransformerDecoder to support relative positional
+    encoding.
+
+    Args:
+        decoder_layer: an instance of the TransformerDecoderLayerRelPos() class (required).
+        num_layers: the number of sub-decoder-layers in the decoder (required).
+        norm: the layer normalization component (optional).
+
+    Examples::
+        >>> decoder_layer = TransformerDecoderLayerRelPos(d_model=512, nhead=8)
+        >>> transformer_decoder = TransformerDecoderRelPos(decoder_layer, num_layers=6)
+        >>> memory = torch.rand(20, 32, 512)
+        >>> tgt = torch.rand(20, 32, 512)
+        >>> pos_emb = torch.rand(1, 39, 512)  # (1, 2*20-1, 512)
+        >>> out = transformer_decoder(tgt, pos_emb, memory)
+    """
+    __constants__ = ['norm']
+
+    def __init__(self, decoder_layer, num_layers, norm=None):
+        super(TransformerDecoderRelPos, self).__init__()
+        self.layers = _get_clones(decoder_layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = norm
+
+    def forward(self, x: Tensor,
+                pos_emb: Tensor,
+                memory: Tensor,
+                attn_mask: Optional[Tensor] = None,
+                key_padding_mask: Optional[Tensor] = None) -> Tensor:
+        r"""Pass the inputs (and masks) through each decoder layer in turn.
+
+        Args:
+            x: the input embedding sequence to the decoder (required): shape = (T, N, C).
+               Will be an embedding of `src_symbols` in practice.
+            pos_emb:
+               A torch.Tensor with dtype=torch.float and shape (1, 2*T-1, C), with C==num_channels,
+               representing the relative positional encoding.
+            memory: the sequence from the last layer of the encoder (required):
+               shape = (T, N, C)
+            attn_mask: the mask for the `x` sequence's attention to itself,
+               of shape (T, T); in practice it will ensure that no position
+               can attend to later positions.  A torch.Tensor with dtype=torch.float
+               or dtype=torch.bool.
+            key_padding_mask: the key-padding mask for both the memory and x sequences,
+               a torch.Tensor with dtype=torch.bool and shape (N, T): True for masked
+               positions after the ends of sequences.
+        """
+        for mod in self.layers:
+            x = mod(x, pos_emb, memory, x_mask=attn_mask,
+                    key_padding_mask=key_padding_mask)
+
+        if self.norm is not None:
+            x = self.norm(x)
+
+        return x
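The decoder layer defined next hard-codes the pre-norm residual pattern: LayerNorm is applied before each sub-module, and the residual path is left untouched. In isolation the pattern is just this; `sublayer` is a stand-in for self-attention, cross-attention, or the feed-forward block:

```python
import torch
import torch.nn as nn


def pre_norm_step(x: torch.Tensor, norm: nn.LayerNorm, sublayer,
                  dropout: nn.Dropout) -> torch.Tensor:
    # Normalize first, run the sub-module, then add back the residual.
    return x + dropout(sublayer(norm(x)))
```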
+
+
+class TransformerDecoderLayerRelPos(nn.Module):
+    """
+    Modified from torch.nn.TransformerDecoderLayer, with three changes:
+    pre-normalization is hard-coded (i.e. layer_norm is applied before each block);
+    relative positional encoding is used; and the interface is changed/simplified
+    a little because both sequences have the same length and the same mask.
+
+    Args:
+        d_model:
+            the number of expected features in the input (required).
+        nhead:
+            the number of heads in the multiheadattention models (required).
+        dim_feedforward:
+            the dimension of the feedforward network model (default=2048).
+        dropout:
+            the dropout value (default=0.1).
+        activation:
+            the activation function of the intermediate layer, relu or
+            gelu (default=relu).
+
+    Examples::
+        >>> decoder_layer = TransformerDecoderLayerRelPos(d_model=512, nhead=8)
+        >>> memory = torch.rand(20, 32, 512)
+        >>> tgt = torch.rand(20, 32, 512)
+        >>> pos_emb = torch.rand(1, 39, 512)  # (1, 2*20-1, 512)
+        >>> out = decoder_layer(tgt, pos_emb, memory)
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        activation: str = "relu",
+    ) -> None:
+        super(TransformerDecoderLayerRelPos, self).__init__()
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
+        self.src_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.dropout3 = nn.Dropout(dropout)
+
+        self.activation = _get_activation_fn(activation)
+
+    def __setstate__(self, state):
+        if "activation" not in state:
+            state["activation"] = nn.functional.relu
+        super(TransformerDecoderLayerRelPos, self).__setstate__(state)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        pos_emb: torch.Tensor,
+        memory: torch.Tensor,
+        x_mask: Optional[torch.Tensor] = None,
+        key_padding_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Pass the inputs (and masks) through the decoder layer.
+
+        Args:
+          x:
+            The input embedding, to be added to by the forward function, of shape (T, N, C).
+            Attention within x will be left-to-right only (causal), thanks to x_mask.
+          pos_emb:
+            A torch.Tensor with dtype=torch.float and shape (1, 2*T-1, C), with C==num_channels,
+            containing the relative positional encoding.
+          memory:
+            the sequence from the last layer of the encoder (required).  Shape = (T, N, C)
+          x_mask:
+            the mask for x, to enforce causal (left-to-right) attention (optional).
+            Shape == (T, T); may be bool or float.  The first T pertains to the output,
+            the second T to the input.
+          key_padding_mask:
+            the key-padding mask to use for both the x and memory sequences.  Shape == (N, T);
+            may be bool (True==masked) or float (to be added to the attention scores).
+
+        Returns:
+            Returns 'x plus something', a torch.Tensor with the same dtype as x (e.g. float),
+            and shape (T, N, C).
+ """ + residual = x + x = self.norm1(x) + self_attn = self.self_attn(x, x, x, + key_padding_mask=key_padding_mask, + need_weights=False, + attn_mask=x_mask, + )[0] + x = residual + self.dropout1(self_attn) + + residual = x + x = self.norm2(x) + src_attn = self.src_attn(x, memory, memory, + key_padding_mask=key_padding_mask, + need_weights=False, + )[0] + x = residual + self.dropout2(src_attn) + + residual = x + x = self.norm3(x) + ff = self.linear2(self.dropout(self.activation(self.linear1(x)))) + x = residual + self.dropout3(ff) + return x + + +def _get_activation_fn(activation: str): + if activation == "relu": + return nn.functional.relu + elif activation == "gelu": + return nn.functional.gelu + + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) + + +class PositionalEncoding(nn.Module): + """This class implements the positional encoding + proposed in the following paper: + + - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf + + PE(pos, 2i) = sin(pos / (10000^(2i/d_modle)) + PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle)) + + Note:: + + 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model))) + = exp(-1* 2i / d_model * log(100000)) + = exp(2i * -(log(10000) / d_model)) + """ + + def __init__(self, d_model: int, dropout: float = 0.1) -> None: + """ + Args: + d_model: + Embedding dimension. + dropout: + Dropout probability to be applied to the output of this module. + """ + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = nn.Dropout(p=dropout) + self.pe = None + + def extend_pe(self, x: torch.Tensor) -> None: + """Extend the time t in the positional encoding if required. + + The shape of `self.pe` is [1, T1, d_model]. The shape of the input x + is [N, T, d_model]. If T > T1, then we change the shape of self.pe + to [N, T, d_model]. Otherwise, nothing is done. + + Args: + x: + It is a tensor of shape [N, T, C]. + Returns: + Return None. + """ + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + # Now pe is of shape [1, T, d_model], where T is x.size(1) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Add positional encoding. + + Args: + x: + Its shape is [N, T, C] + + Returns: + Return a tensor of shape [N, T, C] + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1), :] + return self.dropout(x) + + +class Noam(object): + """ + Implements Noam optimizer. 
+class Noam(object):
+    """
+    Implements the Noam optimizer.
+
+    Proposed in
+    "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
+
+    Modified from
+    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py  # noqa
+
+    Args:
+      params:
+        iterable of parameters to optimize or dicts defining parameter groups
+      model_size:
+        attention dimension of the transformer model
+      factor:
+        learning rate factor
+      warm_step:
+        warmup steps
+    """
+
+    def __init__(
+        self,
+        params,
+        model_size: int = 256,
+        factor: float = 10.0,
+        warm_step: int = 25000,
+        weight_decay=0,
+    ) -> None:
+        """Construct a Noam object."""
+        self.optimizer = torch.optim.Adam(
+            params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
+        )
+        self._step = 0
+        self.warmup = warm_step
+        self.factor = factor
+        self.model_size = model_size
+        self._rate = 0
+
+    @property
+    def param_groups(self):
+        """Return param_groups."""
+        return self.optimizer.param_groups
+
+    def step(self):
+        """Update parameters and rate."""
+        self._step += 1
+        rate = self.rate()
+        for p in self.optimizer.param_groups:
+            p["lr"] = rate
+        self._rate = rate
+        self.optimizer.step()
+
+    def rate(self, step=None):
+        """Compute the learning rate for a given step (default: the current step)."""
+        if step is None:
+            step = self._step
+        return (
+            self.factor
+            * self.model_size ** (-0.5)
+            * min(step ** (-0.5), step * self.warmup ** (-1.5))
+        )
+
+    def zero_grad(self):
+        """Reset gradient."""
+        self.optimizer.zero_grad()
+
+    def state_dict(self):
+        """Return state_dict."""
+        return {
+            "_step": self._step,
+            "warmup": self.warmup,
+            "factor": self.factor,
+            "model_size": self.model_size,
+            "_rate": self._rate,
+            "optimizer": self.optimizer.state_dict(),
+        }
+
+    def load_state_dict(self, state_dict):
+        """Load state_dict."""
+        for key, value in state_dict.items():
+            if key == "optimizer":
+                self.optimizer.load_state_dict(state_dict["optimizer"])
+            else:
+                setattr(self, key, value)
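The schedule computed by `rate()` grows linearly for `warm_step` steps and then decays as `step^-0.5`; the two arguments of `min()` coincide exactly at `step == warm_step`, which is where the peak learning rate occurs. A standalone sketch with the class defaults:

```python
def noam_rate(step: int, model_size: int = 256, factor: float = 10.0,
              warm_step: int = 25000) -> float:
    return (
        factor
        * model_size ** (-0.5)
        * min(step ** (-0.5), step * warm_step ** (-1.5))
    )


for step in (1000, 25000, 100000):
    print(step, noam_rate(step))  # the peak is at step == warm_step
```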
+ """ + assert x.size(2) == self.size + # batch_size = x.size(0) + x = x.view(-1, self.size) + target = target.view(-1) + with torch.no_grad(): + true_dist = x.clone() + true_dist.fill_(self.smoothing / (self.size - 1)) + ignore = target == self.padding_idx # (B,) + total = len(target) - ignore.sum().item() + target = target.masked_fill(ignore, 0) # avoid -1 index + true_dist.scatter_(1, target.unsqueeze(1), self.confidence) + kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) + # denom = total if self.normalize_length else batch_size + denom = total if self.normalize_length else 1 + return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom + + + +def generate_square_subsequent_mask(sz: int, device: torch.device = torch.device('cpu')) -> torch.Tensor: + """Generate a square mask for the sequence. The masked positions are + filled with float('-inf'). Unmasked positions are filled with float(0.0). + The mask can be used for masked self-attention. + + For instance, if sz is 3, it returns:: + + tensor([[0., -inf, -inf], + [0., 0., -inf], + [0., 0., 0]]) + + Args: + sz: mask size + + Returns: + A square mask of dimension (sz, sz) + """ + mask = (torch.triu(torch.ones(sz, sz, device=device)) == 1).transpose(0, 1) + mask = ( + mask.float() + .masked_fill(mask == 0, float("-inf")) + .masked_fill(mask == 1, float(0.0)) + ) + return mask + + +def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: + """Prepend sos_id to each utterance. + + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + sos_id: + The ID of the SOS token. + + Return: + Return a new list-of-list, where each sublist starts + with SOS ID. + """ + ans = [] + for utt in token_ids: + ans.append([sos_id] + utt) + return ans + + +def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: + """Append eos_id to each utterance. + + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + eos_id: + The ID of the EOS token. + + Return: + Return a new list-of-list, where each sublist ends + with EOS ID. + """ + ans = [] + for utt in token_ids: + ans.append(utt + [eos_id]) + return ans + + + +class MaskedConvolutionModule(nn.Module): + """ + This is used in the MaskedLmConformerLayer. It is the same as the ConvolutionModule + of theConformer code, but with key_padding_mask supported to make the output independent + of the batching. + + Modified, ultimately, from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). 
+ """ + + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: + """Construct a MaskedConvolutionModule object.""" + super(MaskedConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + self.norm = nn.LayerNorm(channels) + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = Swish() + + def forward(self, x: Tensor, key_padding_mask: Optional[Tensor]) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (T, N, C) == (#time, batch, channels). + key_padding_mask: if supplied, a Tensor with dtype=torch.Bool and + shape (N, T), with True for positions that correspond to + padding (and should be zeroed in convolutions). + + Returns: + Tensor: Output tensor (T, N, C) + + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # Logical-not key_padding_mask, unsqueeze to shape (N, 1, T) and convert + # to float. Then we can just multiply by it when we need to apply + # masking, i.e. prior to the convolution over time. + if key_padding_mask is not None: + x = x * torch.logical_not(key_padding_mask).unsqueeze(1).to(dtype=x.dtype) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + x = self.activation(self.norm(x)) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) # (time, batch, channel) + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x: Tensor) -> Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x) + + + +class MaskedLmConformerEncoderLayer(nn.Module): + """ + MaskedLmConformerEncoderLayer is made up of self-attn, feedforward and convolution + networks. It's a simplified version of the conformer code we were previously + using, with pre-normalization hard-coded, relative positional encoding, + LayerNorm instead of BatchNorm in the convolution layers, and the key_padding_mask + applied also in the convolution layers so the computation is independent of + how the sequences are batched. + + See: "Conformer: Convolution-augmented Transformer for Speech Recognition", for + the basic conformer. Args: d_model: the number of expected features in the input (required). @@ -121,7 +791,6 @@ class ConformerEncoderLayer(nn.Module): dim_feedforward: the dimension of the feedforward network model (default=2048). dropout: the dropout value (default=0.1). cnn_module_kernel (int): Kernel size of convolution module. - normalize_before: whether to use layer_norm before the first block. 
+
+
+class MaskedLmConformerEncoderLayer(nn.Module):
+    """
+    MaskedLmConformerEncoderLayer is made up of self-attn, feedforward and convolution
+    networks.  It is a simplified version of the conformer code we were previously
+    using, with pre-normalization hard-coded, relative positional encoding,
+    LayerNorm instead of BatchNorm in the convolution layers, and the key_padding_mask
+    applied also in the convolution layers, so that the computation is independent of
+    how the sequences are batched.
+
+    See: "Conformer: Convolution-augmented Transformer for Speech Recognition", for
+    the basic conformer.
 
     Args:
         d_model: the number of expected features in the input (required).
@@ -121,7 +791,6 @@ class ConformerEncoderLayer(nn.Module):
         dim_feedforward: the dimension of the feedforward network model (default=2048).
         dropout: the dropout value (default=0.1).
         cnn_module_kernel (int): Kernel size of convolution module.
-        normalize_before: whether to use layer_norm before the first block.
 
     Examples::
-        >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
+        >>> encoder_layer = MaskedLmConformerEncoderLayer(d_model=512, nhead=8)
@@ -137,12 +806,10 @@ class ConformerEncoderLayer(nn.Module):
         dim_feedforward: int = 2048,
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
-        normalize_before: bool = True,
-        is_espnet_structure: bool = False,
     ) -> None:
-        super(ConformerEncoderLayer, self).__init__()
+        super(MaskedLmConformerEncoderLayer, self).__init__()
         self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0, is_espnet_structure=is_espnet_structure
+            d_model, nhead, dropout=0.0
         )
 
         self.feed_forward = nn.Sequential(
@@ -159,7 +826,7 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )
 
-        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
+        self.conv_module = MaskedConvolutionModule(d_model, cnn_module_kernel)
 
         self.norm_ff_macaron = nn.LayerNorm(
             d_model
@@ -176,140 +843,129 @@ class ConformerEncoderLayer(nn.Module):
 
         self.dropout = nn.Dropout(dropout)
 
-        self.normalize_before = normalize_before
-
     def forward(
         self,
-        src: Tensor,
+        x: Tensor,
         pos_emb: Tensor,
-        src_mask: Optional[Tensor] = None,
-        src_key_padding_mask: Optional[Tensor] = None,
+        attn_mask: Optional[Tensor] = None,
+        key_padding_mask: Optional[Tensor] = None,
     ) -> Tensor:
         """
         Pass the input through the encoder layer.
 
         Args:
-            src: the sequence to the encoder layer (required).
+            x: the sequence to the encoder layer (required).
             pos_emb: Positional embedding tensor (required).
-            src_mask: the mask for the src sequence (optional).
-            src_key_padding_mask: the mask for the src keys per batch (optional).
+            attn_mask: the mask for the x sequence's attention to itself (optional);
+               of shape (T, T)
+            key_padding_mask: the mask for the src keys per batch (optional).
 
         Shape:
-            src: (S, N, E).
-            pos_emb: (N, 2*S-1, E)
-            src_mask: (S, S).
-            src_key_padding_mask: (N, S).
-            S is the source sequence length, N is the batch size, E is the feature number
+            x: (T, N, C), i.e. (seq_len, batch_size, num_channels)
+            pos_emb: (1, 2*T-1, C)
+            attn_mask: (T, T) or (N*num_heads, T, T), of dtype torch.bool or torch.float, where
+               the 1st T is interpreted as the target sequence (output) and the 2nd as the
+               source sequence (input).
+            key_padding_mask: (N, T), of dtype torch.bool
+
+            T is the sequence length, N is the batch size, C is the number of channels.
+        Return:
+            Returns x with something added to it, of shape (T, N, C)
         """
         # macaron style feed forward module
-        residual = src
-        if self.normalize_before:
-            src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
+        residual = x
+        x = self.norm_ff_macaron(x)
+        x = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(x)
         )
-        if not self.normalize_before:
-            src = self.norm_ff_macaron(src)
 
         # multi-headed self-attention module
-        residual = src
-        if self.normalize_before:
-            src = self.norm_mha(src)
-        src_att = self.self_attn(
-            src,
-            src,
-            src,
-            pos_emb=pos_emb,
-            attn_mask=src_mask,
-            key_padding_mask=src_key_padding_mask,
+        residual = x
+        x = self.norm_mha(x)
+        self_attn = self.self_attn(x, x, x,
+                                   pos_emb=pos_emb,
+                                   attn_mask=attn_mask,
+                                   key_padding_mask=key_padding_mask,
+                                   need_weights=False
         )[0]
-        src = residual + self.dropout(src_att)
-        if not self.normalize_before:
-            src = self.norm_mha(src)
+        x = residual + self.dropout(self_attn)
 
         # convolution module
-        residual = src
-        if self.normalize_before:
-            src = self.norm_conv(src)
-        src = residual + self.dropout(self.conv_module(src))
-        if not self.normalize_before:
-            src = self.norm_conv(src)
+        residual = x
+        x = self.norm_conv(x)
+
+        x = residual + self.dropout(self.conv_module(x, key_padding_mask=key_padding_mask))
 
         # feed forward module
-        residual = src
-        if self.normalize_before:
-            src = self.norm_ff(src)
-        src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
-        if not self.normalize_before:
-            src = self.norm_ff(src)
+        residual = x
+        x = self.norm_ff(x)
+        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
 
-        if self.normalize_before:
-            src = self.norm_final(src)
+        x = self.norm_final(x)
 
-        return src
+        return x
 
 
-class ConformerEncoder(nn.TransformerEncoder):
-    r"""ConformerEncoder is a stack of N encoder layers
+def _get_clones(module, N):
+    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
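Mirroring what `MaskedLmConformer.forward` does internally, the layer above and the encoder stack defined just below wire together like this; the sizes are arbitrary:

```python
import torch
import torch.nn as nn

from conformer import (
    MaskedLmConformerEncoder,
    MaskedLmConformerEncoderLayer,
    RelPositionalEncoding,
)

d_model = 256
layer = MaskedLmConformerEncoderLayer(d_model, nhead=4)
encoder = MaskedLmConformerEncoder(layer, num_layers=2,
                                   norm=nn.LayerNorm(d_model))
pos = RelPositionalEncoding(d_model, dropout_rate=0.1)

N, T = 4, 25
x = torch.randn(N, T, d_model)
x, pos_emb = pos(x)        # x: (N, T, C), pos_emb: (1, 2*T-1, C)
x = x.permute(1, 0, 2)     # (T, N, C)
out = encoder(x, pos_emb)  # (T, N, C)
```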
+
+
+class MaskedLmConformerEncoder(nn.Module):
+    r"""MaskedLmConformerEncoder is a stack of N encoder layers, modified from
+    torch.nn.TransformerEncoder.  The only differences are some name changes
+    for parameters.
 
     Args:
-        encoder_layer: an instance of the ConformerEncoderLayer() class (required).
+        encoder_layer: an instance of the MaskedLmConformerEncoderLayer() class (required).
         num_layers: the number of sub-encoder-layers in the encoder (required).
         norm: the layer normalization component (optional).
 
     Examples::
-        >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
-        >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
+        >>> encoder_layer = MaskedLmConformerEncoderLayer(d_model=512, nhead=8)
+        >>> conformer_encoder = MaskedLmConformerEncoder(encoder_layer, num_layers=6)
         >>> src = torch.rand(10, 32, 512)
-        >>> pos_emb = torch.rand(32, 19, 512)
+        >>> pos_emb = torch.rand(1, 19, 512)  # (1, 2*10-1, 512)
         >>> out = conformer_encoder(src, pos_emb)
     """
+    __constants__ = ['norm']
+
+    def __init__(self, encoder_layer: nn.Module, num_layers: int,
+                 norm: Optional[nn.Module] = None):
+        super(MaskedLmConformerEncoder, self).__init__()
+        self.layers = _get_clones(encoder_layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = norm
 
-    def __init__(
-        self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None
-    ) -> None:
-        super(ConformerEncoder, self).__init__(
-            encoder_layer=encoder_layer, num_layers=num_layers, norm=norm
-        )
     def forward(
         self,
-        src: Tensor,
+        x: Tensor,
         pos_emb: Tensor,
-        mask: Optional[Tensor] = None,
-        src_key_padding_mask: Optional[Tensor] = None,
+        attn_mask: Optional[Tensor] = None,
+        key_padding_mask: Optional[Tensor] = None,
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.
-
-        Args:
-            src: the sequence to the encoder (required).
-            pos_emb: Positional embedding tensor (required).
-            mask: the mask for the src sequence (optional).
-            src_key_padding_mask: the mask for the src keys per batch (optional).
-
-        Shape:
-            src: (S, N, E).
-            pos_emb: (N, 2*S-1, E)
-            mask: (S, S).
-            src_key_padding_mask: (N, S).
-            S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
-
+        Args:
+            x: input of shape (T, N, C), i.e. (seq_len, batch, channels)
+            pos_emb: positional embedding tensor of shape (1, 2*T-1, C)
+            attn_mask (optional, likely not used): mask for self-attention of
+               x to itself, of shape (T, T)
+            key_padding_mask (optional): mask of shape (N, T); dtype must be bool.
+        Returns:
+            Returns a tensor with the same shape as x, i.e. (T, N, C).
         """
-        output = src
-
         for mod in self.layers:
-            output = mod(
-                output,
+            x = mod(
+                x,
                 pos_emb,
-                src_mask=mask,
-                src_key_padding_mask=src_key_padding_mask,
+                attn_mask=attn_mask,
+                key_padding_mask=key_padding_mask,
             )
 
         if self.norm is not None:
-            output = self.norm(output)
+            x = self.norm(x)
 
-        return output
+        return x
 
 
 class RelPositionalEncoding(torch.nn.Module):
@@ -331,7 +987,6 @@ class RelPositionalEncoding(torch.nn.Module):
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
-        self.xscale = math.sqrt(self.d_model)
         self.dropout = torch.nn.Dropout(p=dropout_rate)
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
@@ -348,7 +1003,7 @@ class RelPositionalEncoding(torch.nn.Module):
         ):
             self.pe = self.pe.to(dtype=x.dtype, device=x.device)
             return
-        # Suppose `i` means to the position of query vecotr and `j` means the
+        # Suppose `i` is the position of the query vector and `j` the
         # position of key vector.
-        # We use position relative positions when keys
-        # are to the left (i>j) and negative relative positions otherwise (i<j).
+        # We use positive relative positions when keys
+        # are to the left (i>j) and negative relative positions otherwise (i<j).
 
     Examples::
         >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
-        >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb)
+        >>> attn_output, attn_output_weights = rel_pos_multihead_attn(query, key, value, pos_emb)
     """
 
     def __init__(
@@ -415,7 +1069,6 @@ class RelPositionMultiheadAttention(nn.Module):
         embed_dim: int,
         num_heads: int,
         dropout: float = 0.0,
-        is_espnet_structure: bool = False,
     ) -> None:
         super(RelPositionMultiheadAttention, self).__init__()
         self.embed_dim = embed_dim
@@ -438,8 +1091,6 @@ class RelPositionMultiheadAttention(nn.Module):
 
         self._reset_parameters()
 
-        self.is_espnet_structure = is_espnet_structure
-
     def _reset_parameters(self) -> None:
         nn.init.xavier_uniform_(self.in_proj.weight)
         nn.init.constant_(self.in_proj.bias, 0.0)
@@ -459,7 +1110,7 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_mask: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         r"""
-        Args:
+        Args (see below for shapes):
             query, key, value: map a query and a set of key-value pairs to an output.
             pos_emb: Positional embedding tensor
             key_padding_mask: if provided, specified padding elements in the key will
                 be ignored by the attention. When given a binary mask and a value is True,
                 the corresponding value on the attention layer will be ignored. When given
                 a byte mask and a value is non-zero, the corresponding value on the attention
                 layer will be ignored
-            need_weights: output attn_output_weights.
+            need_weights: if True, return (output, attn_output_weights); else, (output, None).
             attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
                 the batches while a 3D mask allows to specify a different mask for the entries of each batch.
 
         Shape:
-            - Inputs:
-            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
-              the embedding dimension.
-            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
-              the embedding dimension.
-            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
-              the embedding dimension.
-            - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
-              the embedding dimension.
-            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+            - query: :math:`(T, N, C)` where T is the output sequence length, N is the batch size, C is
+              the embedding dimension (number of channels).
+            - key: :math:`(S, N, C)`, where S is the input sequence length.
+            - value: :math:`(S, N, C)`
+            - pos_emb: :math:`(N, 2*T-1, C)` or :math:`(1, 2*T-1, C)`.  Note: this assumes T == S, which it
+              will be, but we still use different letters because S relates to the input position and T to
+              the output position.
+            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the input sequence length.
              If a ByteTensor is provided, the non-zero positions will be ignored while the position
              with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
              value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
-            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
-              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
-              S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
-              positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
-              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
-              is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
-              is provided, it will be added to the attention weight.
-
-            - Outputs:
-            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
-              E is the embedding dimension.
-            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
-              L is the target sequence length, S is the source sequence length.
+            - attn_mask: 2D mask :math:`(T, S)` where T is the output sequence length and S is the input
+              sequence length; or a 3D mask :math:`(N*num_heads, T, S)`.  attn_mask ensures that position i
+              attends only to the unmasked positions.  If a ByteTensor is provided, the non-zero positions
+              are not allowed to attend while the zero positions will be unchanged.  If a BoolTensor is
+              provided, positions with ``True`` are not allowed to attend while ``False`` values will be
+              unchanged.  If a FloatTensor is provided, it will be added to the attention weight.
+
+        Return:
+            (output, attn_output_weights) if need_weights==True, else (output, None), where:
+
+            - output: :math:`(T, N, C)` where T is the output sequence length, N is the batch size and
+              C is the embedding/channel dimension.
+            - attn_output_weights: :math:`(N, T, S)` where N is the batch size,
+              T is the output sequence length and S is the input sequence length (in fact
+              S and T are equal).
         """
         return self.multi_head_attention_forward(
             query,
@@ -669,8 +1323,8 @@ class RelPositionMultiheadAttention(nn.Module):
                 _b = _b[_start:]
             v = nn.functional.linear(value, _w, _b)
 
-        if not self.is_espnet_structure:
-            q = q * scaling
+        # Note: unlike nn.MultiheadAttention, the scaling is applied to the
+        # attention weights below, not to q here.
 
         if attn_mask is not None:
             assert (
@@ -764,14 +1418,15 @@ class RelPositionMultiheadAttention(nn.Module):
         )  # (batch, head, time1, 2*time1-1)
         matrix_bd = self.rel_shift(matrix_bd)
 
-        if not self.is_espnet_structure:
-            attn_output_weights = (
-                matrix_ac + matrix_bd
-            )  # (batch, head, time1, time2)
-        else:
-            attn_output_weights = (
-                matrix_ac + matrix_bd
-            ) * scaling  # (batch, head, time1, time2)
+        attn_output_weights = (
+            matrix_ac + matrix_bd
+        ) * scaling  # (batch, head, time1, time2)
 
         attn_output_weights = attn_output_weights.view(
             bsz * num_heads, tgt_len, -1
@@ -825,88 +1480,3 @@ class RelPositionMultiheadAttention(nn.Module):
             return attn_output, attn_output_weights.sum(dim=1) / num_heads
         else:
             return attn_output, None
-
-
-class ConvolutionModule(nn.Module):
-    """ConvolutionModule in Conformer model.
-    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
-
-    Args:
-        channels (int): The number of channels of conv layers.
-        kernel_size (int): Kernerl size of conv layers.
-        bias (bool): Whether to use bias in conv layers (default=True).
- - """ - - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: - """Construct an ConvolutionModule object.""" - super(ConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0 - - self.pointwise_conv1 = nn.Conv1d( - channels, - 2 * channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - self.depthwise_conv = nn.Conv1d( - channels, - channels, - kernel_size, - stride=1, - padding=(kernel_size - 1) // 2, - groups=channels, - bias=bias, - ) - self.norm = nn.LayerNorm(channels) - self.pointwise_conv2 = nn.Conv1d( - channels, - channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - self.activation = Swish() - - def forward(self, x: Tensor) -> Tensor: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - - Returns: - Tensor: Output tensor (#time, batch, channels). - - """ - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channels, time) - x = nn.functional.glu(x, dim=1) # (batch, channels, time) - - # 1D Depthwise Conv - x = self.depthwise_conv(x) - x = self.activation(self.norm(x)) - - x = self.pointwise_conv2(x) # (batch, channel, time) - - return x.permute(2, 0, 1) - - -class Swish(torch.nn.Module): - """Construct an Swish object.""" - - def forward(self, x: Tensor) -> Tensor: - """Return Swich activation function.""" - return x * torch.sigmoid(x) - - -def identity(x): - return x diff --git a/egs/librispeech/ASR/conformer_lm/test_conformer.py b/egs/librispeech/ASR/conformer_lm/test_conformer.py new file mode 100644 index 000000000..8aaae4277 --- /dev/null +++ b/egs/librispeech/ASR/conformer_lm/test_conformer.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# run with: +# python3 -m pytest test_conformer.py + +import torch +from conformer import ( + TransformerDecoderRelPos, + MaskedLmConformer, + RelPositionMultiheadAttention, + RelPositionalEncoding, + generate_square_subsequent_mask, +) + +from torch.nn.utils.rnn import pad_sequence + + +def test_rel_position_multihead_attention(): + # Also tests RelPositionalEncoding + embed_dim = 256 + num_heads = 4 + T = 25 + N = 4 + C = 256 + pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0) + rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + + x = torch.randn(N, T, C) + #pos_emb = torch.randn(1, 2*T-1, C) + x, pos_enc = pos_emb_module(x) + print("pos_enc.shape=", pos_enc.shape) + x = x.transpose(0, 1) # (T, N, C) + attn_output, attn_output_weights = rel_pos_multihead_attn(x, x, x, pos_enc) + + +def test_transformer(): + return + num_features = 40 + num_classes = 87 + model = Transformer(num_features=num_features, num_classes=num_classes) + + N = 31 + + for T in range(7, 30): + x = torch.rand(N, T, num_features) + y, _, _ = model(x) + assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes) + + +def test_generate_square_subsequent_mask(): + s = 5 + mask = generate_square_subsequent_mask(s, torch.device('cpu')) + inf = float("inf") + expected_mask = torch.tensor( + [ + [0.0, -inf, -inf, -inf, -inf], + [0.0, 0.0, -inf, -inf, -inf], + [0.0, 0.0, 0.0, -inf, -inf], + [0.0, 0.0, 0.0, 0.0, -inf], + [0.0, 0.0, 0.0, 0.0, 0.0], + ] + ) + assert torch.all(torch.eq(mask, expected_mask)) diff --git a/egs/librispeech/ASR/conformer_lm/test_transformer.py b/egs/librispeech/ASR/conformer_lm/test_transformer.py deleted file mode 100644 
index 08e680607..000000000 --- a/egs/librispeech/ASR/conformer_lm/test_transformer.py +++ /dev/null @@ -1,89 +0,0 @@ -#!/usr/bin/env python3 - -import torch -from transformer import ( - Transformer, - encoder_padding_mask, - generate_square_subsequent_mask, - decoder_padding_mask, - add_sos, - add_eos, -) - -from torch.nn.utils.rnn import pad_sequence - - -def test_encoder_padding_mask(): - supervisions = { - "sequence_idx": torch.tensor([0, 1, 2]), - "start_frame": torch.tensor([0, 0, 0]), - "num_frames": torch.tensor([18, 7, 13]), - } - - max_len = ((18 - 1) // 2 - 1) // 2 - mask = encoder_padding_mask(max_len, supervisions) - expected_mask = torch.tensor( - [ - [False, False, False], # ((18 - 1)//2 - 1)//2 = 3, - [False, True, True], # ((7 - 1)//2 - 1)//2 = 1, - [False, False, True], # ((13 - 1)//2 - 1)//2 = 2, - ] - ) - assert torch.all(torch.eq(mask, expected_mask)) - - -def test_transformer(): - num_features = 40 - num_classes = 87 - model = Transformer(num_features=num_features, num_classes=num_classes) - - N = 31 - - for T in range(7, 30): - x = torch.rand(N, T, num_features) - y, _, _ = model(x) - assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes) - - -def test_generate_square_subsequent_mask(): - s = 5 - mask = generate_square_subsequent_mask(s) - inf = float("inf") - expected_mask = torch.tensor( - [ - [0.0, -inf, -inf, -inf, -inf], - [0.0, 0.0, -inf, -inf, -inf], - [0.0, 0.0, 0.0, -inf, -inf], - [0.0, 0.0, 0.0, 0.0, -inf], - [0.0, 0.0, 0.0, 0.0, 0.0], - ] - ) - assert torch.all(torch.eq(mask, expected_mask)) - - -def test_decoder_padding_mask(): - x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([2, 5, 8])] - y = pad_sequence(x, batch_first=True, padding_value=-1) - mask = decoder_padding_mask(y, ignore_id=-1) - expected_mask = torch.tensor( - [ - [False, False, True], - [False, True, True], - [False, False, False], - ] - ) - assert torch.all(torch.eq(mask, expected_mask)) - - -def test_add_sos(): - x = [[1, 2], [3], [2, 5, 8]] - y = add_sos(x, sos_id=0) - expected_y = [[0, 1, 2], [0, 3], [0, 2, 5, 8]] - assert y == expected_y - - -def test_add_eos(): - x = [[1, 2], [3], [2, 5, 8]] - y = add_eos(x, eos_id=0) - expected_y = [[1, 2, 0], [3, 0], [2, 5, 8, 0]] - assert y == expected_y diff --git a/egs/librispeech/ASR/conformer_lm/transformer.py b/egs/librispeech/ASR/conformer_lm/transformer.py deleted file mode 100644 index 4367808a8..000000000 --- a/egs/librispeech/ASR/conformer_lm/transformer.py +++ /dev/null @@ -1,1501 +0,0 @@ -# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) -# Apache 2.0 - -import math -from typing import Dict, List, Optional, Tuple - -import torch -import torch.nn as nn -from torch.nn.utils.rnn import pad_sequence - -# Note: TorchScript requires Dict/List/etc. to be fully typed. -Supervisions = Dict[str, torch.Tensor] - - -class MaskedLmConformer(nn.Module): - def __init__( - self, - num_classes: int, - d_model: int = 256, - nhead: int = 4, - dim_feedforward: int = 2048, - num_encoder_layers: int = 12, - num_decoder_layers: int = 6, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - ) -> None: - """ - Args: - num_classes: - The input and output dimension of the model (inputs and outputs are - both discrete) - d_model: - Attention dimension. - nhead: - Number of heads in multi-head attention. - Must satisfy d_model // nhead == 0. - dim_feedforward: - The output dimension of the feedforward layers in encoder/decoder. - num_encoder_layers: - Number of encoder layers. 
- num_decoder_layers: - Number of decoder layers. - dropout: - Dropout in encoder/decoder. - """ - super(MaskedLmConformer, self).__init__() - - self.num_classes = num_classes - - # self.embed is the embedding used for both the encoder and decoder. - self.embed_scale = d_model ** 0.5 - self.embed = nn.Embedding( - num_embeddings=self.decoder_num_class, embedding_dim=d_model, - _weight=torch.randn(self.decoder_num_class, d_model) * (1 / self.embed_scale) - ) - - self.encoder_pos = RelPositionalEncoding(d_model, dropout) - - encoder_layer = MaskedLmConformerEncoderLayer( - d_model, - nhead, - dim_feedforward, - dropout, - cnn_module_kernel, - ) - self.encoder = MaskedLmConformerEncoder(encoder_layer, num_encoder_layers, - norm=nn.LayerNorm(d_model)) - - if num_decoder_layers > 0: - self.decoder_num_class = self.num_classes - - decoder_layer = TransformerDecoderLayerRelPos( - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - dropout=dropout, - ) - - # Projects the embedding of `src`, to be added to `memory` - self.src_linear = torch.nn.Linear(d_model, d_model) - - decoder_norm = nn.LayerNorm(d_model) - self.decoder = TransformerDecoderRelPos( - decoder_layer=decoder_layer, - num_layers=num_decoder_layers, - norm=decoder_norm, - ) - - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) - - - def forward( - self, - masked_src_symbols: torch.Tensor, - key_padding_mask: torch.Tensor = None - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - masked_src_symbols: - The input symbols to be embedded (will actually have query positions - masked), as a Tensor of shape (batch_size, seq_len) and dtype=torch.int64. - I.e. shape (N, T) - key_padding_mask: - Either None, or a Tensor of shape (batch_size, seq_len) i.e. (N, T), - and dtype=torch.bool which has True in positions to be masked in attention - layers and convolutions because they represent padding at the ends of - sequences. - - - Returns: - Returns (encoded, pos_emb), where: - `encoded` is a Tensor containing the encoded data; it is of shape (N, T, C) - where C is the embedding_dim. - `pos_emb` is a Tensor containing the relative positional encoding, of - shape (1, 2*T-1, C) - """ - - x = self.embed(masked_src_symbols) * self.embed_scale # (N, T, C) - x, pos_emb = self.encoder_pos(x) # pos_emb: (1, 2*T-1, C) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask) # (T, N, C) - - return x, pos_emb - - def decoder_nll( - self, - memory: torch.Tensor, - pos_emb: torch.Tensor, - src_symbols: torch.Tensor, - tgt_symbols: torch.Tensor, - key_padding_mask: torch.Tensor - ) -> torch.Tensor: - """ - Args: - memory: - The output of the encoder, with shape (T, N, C) - pos_emb: - Relative positional embedding, of shape (1, 2*T-1, C), as - returned from the encoder - src_symbols: - The un-masked src symbols, a LongTensor of shape (N, T). - Can be used to predict the target - only in a left-to-right manner (otherwise it's cheating). - tgt_symbols: - Target symbols, a LongTensor of shape (N, T). - The same as src_symbols, but shifted by one (and also, - without symbol randomization, see randomize_proportion - in dataloader) - key_padding_mask: - A BoolTensor of shape (N, T), with True for positions - that correspond to padding at the end of source and - memory sequences. The same mask is used for self-attention - and cross-attention, since the padding is the same. 
- - Returns: - Returns a tensor of shape (N, T), containing the negative - log-probabilities for the target symbols at each position - in the target sequence. - """ - (T, N, C) = memory.shape - - tgt_mask = generate_square_subsequent_mask(T, memory.device) - - src = self.embed(src_symbols) * self.embed_scale # (N, T) -> (N, T, C) - src = src.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - - src = memory + self.src_linear(src) # (T, N, C) - - # This is a little confusing, how "tgt" is set to src. "src" is the - # symbol sequence without masking but with padding and randomization. - # "tgt" is like "src" but shifted by one. - pred = self.decoder( - tgt=src, - memory=memory, - tgt_mask=tgt_mask, - tgt_key_padding_mask=key_padding_mask, - memory_key_padding_mask=key_padding_mask, - ) # (T, N, C) - - pred = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) - pred = self.decoder_output_layer(pred) # (N, T, C) - - # nll: negative log-likelihood - nll = torch.nn.functional.cross_entropy( - pred.view(-1, self.decoder_num_class), - tgt_symbols.view(-1), - reduction="none", - ) - nll = nll.view(N, T) - return nll - - - - -class TransformerDecoderRelPos(Module): - r"""TransformerDecoderRelPos is a stack of N decoder layers. - This is modified from nn.TransformerDecoder to support relative positional - encoding. - - Args: - decoder_layer: an instance of the TransformerDecoderLayerRelPos() class (required). - num_layers: the number of sub-decoder-layers in the decoder (required). - norm: the layer normalization component (optional). - - Examples:: - >>> decoder_layer = nn.TransformerDecoderLayerRelPos(d_model=512, nhead=8) - >>> transformer_decoder = nn.TransformerDecoderRelPos(decoder_layer, num_layers=6) - >>> memory = torch.rand(10, 32, 512) - >>> tgt = torch.rand(20, 32, 512) - >>> pos_enc = torch.rand() - >>> out = transformer_decoder(tgt, memory) - """ - __constants__ = ['norm'] - - def __init__(self, decoder_layer, num_layers, norm=None): - super(TransformerDecoder, self).__init__() - self.layers = _get_clones(decoder_layer, num_layers) - self.num_layers = num_layers - self.norm = norm - - def forward(self, x: Tensor, - pos_emb: Tensor, - memory: Tensor, - attn_mask: Optional[Tensor] = None, - key_padding_mask: Optional[Tensor] = None) -> Tensor: - r"""Pass the inputs (and mask) through the decoder layer in turn. - - Args: - x: the input embedding sequence to the decoder (required): shape = (T, N, C). - Will be an embedding of `src_symbols` in practice - pos_emb: - A torch.Tensor with dtype=torch.float and shape (1, 2*T-1, C) with c==num_channels, - representing the relative positional encoding. - memory: the sequence from the last layer of the encoder (required): - shape = (T, N, C) - attn_mask: the mask for the `x` sequence's attention to itself, - of shape (T, T); in practice, will ensure that no - position can attend to later positions. A torch.Tensor with dtype=torch.float - or dtype=torch.bool. - key_padding_mask: the key-padding mask for both the memory and x sequences, - a torch.Tensor with dtype=bool and shape (N, T): true for masked - positions after the ends of sequences. - """ - - for mod in self.layers: - x = mod(x, pos_emb, memory, x_mask=x_mask, - key_padding_mask=key_padding_mask) - - if self.norm is not None: - output = self.norm(output) - - return output - - -class TransformerDecoderLayerRelPos(nn.Module): - """ - Modified from torch.nn.TransformerDecoderLayer. - Add it to use normalize_before (hardcoded to True), i.e. 
use layer_norm before the first block; - to use relative positional encoding; and for some changes/simplifications in interface - because both sequences are the same length and have the same mask. - - Args: - d_model: - the number of expected features in the input (required). - nhead: - the number of heads in the multiheadattention models (required). - dim_feedforward: - the dimension of the feedforward network model (default=2048). - dropout: - the dropout value (default=0.1). - activation: - the activation function of intermediate layer, relu or - gelu (default=relu). - - Examples:: - >>> decoder_layer = nn.TransformerDecoderLayerRelPos(d_model=512, nhead=8) - >>> memory = torch.rand(10, 32, 512) - >>> tgt = torch.rand(20, 32, 512) - >>> pos_emb = torch.rand(1, 20*2+1, 512) - >>> out = decoder_layer(tgt, pos_emb, memory) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - activation: str = "relu", - ) -> None: - super(TransformerDecoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) - self.src_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.norm3 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - self.dropout3 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - - - def __setstate__(self, state): - if "activation" not in state: - state["activation"] = nn.functional.relu - super(TransformerDecoderLayer, self).__setstate__(state) - - def forward( - self, - x: torch.Tensor, - pos_emb: torch.Tensor, - memory: torch.Tensor, - x_mask: Optional[torch.Tensor] = None, - key_padding_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Pass the inputs (and mask) through the decoder layer. - - Args: - x - The input embedding, to be added to by the forward function, of shape (T, N, C). - Attention within x will be left-to-right only (causal), thanks to x_mask. - pos_emb: - A torch.Tensor with dtype=torch.float and shape (1, 2*T-1, C) with c==num_channels, - containing the relative positional encoding. - memory: - the sequence from the last layer of the encoder (required). Shape = (T, N, C) - x_mask: - the mask for the x, to enforce causal (left to right) attention (optional). - Shape == (T, T); may be bool or float. The first T pertains to the output, - the second T to the input. - key_padding_mask: - the key-padding mask to use for both the x and memory sequences. Shep == (N, T); - may be bool (True==masked) or float (to be added to attention scores). - - Returns: - Returns 'x plus something', a torch.Tensor with dtype the same as x (e.g. float), - and shape (T, N, C). 
- """ - residual = x - x = self.norm1(x) - self_attn = self.self_attn(x, x, x, - key_padding_mask=key_padding_mask, - need_weights=False - attn_mask=x_mask, - )[0] - x = residual + self.dropout1(self_attn) - - residual = x - x = self.norm2(x) - src_attn = self.src_attn(x, memory, memory, - key_padding_mask=key_padding_mask, - need_weights=False, - )[0] - x = residual + self.dropout2(src_attn) - - residual = x - x = self.norm3(x) - ff = self.linear2(self.dropout(self.activation(self.linear1(x)))) - x = residual + self.dropout3(ff) - return x - - -def _get_activation_fn(activation: str): - if activation == "relu": - return nn.functional.relu - elif activation == "gelu": - return nn.functional.gelu - - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) - - -class PositionalEncoding(nn.Module): - """This class implements the positional encoding - proposed in the following paper: - - - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf - - PE(pos, 2i) = sin(pos / (10000^(2i/d_modle)) - PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle)) - - Note:: - - 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model))) - = exp(-1* 2i / d_model * log(100000)) - = exp(2i * -(log(10000) / d_model)) - """ - - def __init__(self, d_model: int, dropout: float = 0.1) -> None: - """ - Args: - d_model: - Embedding dimension. - dropout: - Dropout probability to be applied to the output of this module. - """ - super().__init__() - self.d_model = d_model - self.xscale = math.sqrt(self.d_model) - self.dropout = nn.Dropout(p=dropout) - self.pe = None - - def extend_pe(self, x: torch.Tensor) -> None: - """Extend the time t in the positional encoding if required. - - The shape of `self.pe` is [1, T1, d_model]. The shape of the input x - is [N, T, d_model]. If T > T1, then we change the shape of self.pe - to [N, T, d_model]. Otherwise, nothing is done. - - Args: - x: - It is a tensor of shape [N, T, C]. - Returns: - Return None. - """ - if self.pe is not None: - if self.pe.size(1) >= x.size(1): - if self.pe.dtype != x.dtype or self.pe.device != x.device: - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) - position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, self.d_model, 2, dtype=torch.float32) - * -(math.log(10000.0) / self.d_model) - ) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0) - # Now pe is of shape [1, T, d_model], where T is x.size(1) - self.pe = pe.to(device=x.device, dtype=x.dtype) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Add positional encoding. - - Args: - x: - Its shape is [N, T, C] - - Returns: - Return a tensor of shape [N, T, C] - """ - self.extend_pe(x) - x = x * self.xscale + self.pe[:, : x.size(1), :] - return self.dropout(x) - - -class Noam(object): - """ - Implements Noam optimizer. 
- - Proposed in - "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf - - Modified from - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa - - Args: - params: - iterable of parameters to optimize or dicts defining parameter groups - model_size: - attention dimension of the transformer model - factor: - learning rate factor - warm_step: - warmup steps - """ - - def __init__( - self, - params, - model_size: int = 256, - factor: float = 10.0, - warm_step: int = 25000, - weight_decay=0, - ) -> None: - """Construct an Noam object.""" - self.optimizer = torch.optim.Adam( - params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay - ) - self._step = 0 - self.warmup = warm_step - self.factor = factor - self.model_size = model_size - self._rate = 0 - - @property - def param_groups(self): - """Return param_groups.""" - return self.optimizer.param_groups - - def step(self): - """Update parameters and rate.""" - self._step += 1 - rate = self.rate() - for p in self.optimizer.param_groups: - p["lr"] = rate - self._rate = rate - self.optimizer.step() - - def rate(self, step=None): - """Implement `lrate` above.""" - if step is None: - step = self._step - return ( - self.factor - * self.model_size ** (-0.5) - * min(step ** (-0.5), step * self.warmup ** (-1.5)) - ) - - def zero_grad(self): - """Reset gradient.""" - self.optimizer.zero_grad() - - def state_dict(self): - """Return state_dict.""" - return { - "_step": self._step, - "warmup": self.warmup, - "factor": self.factor, - "model_size": self.model_size, - "_rate": self._rate, - "optimizer": self.optimizer.state_dict(), - } - - def load_state_dict(self, state_dict): - """Load state_dict.""" - for key, value in state_dict.items(): - if key == "optimizer": - self.optimizer.load_state_dict(state_dict["optimizer"]) - else: - setattr(self, key, value) - - -class LabelSmoothingLoss(nn.Module): - """ - Label-smoothing loss. KL-divergence between q_{smoothed ground truth prob.}(w) - and p_{prob. computed by model}(w) is minimized. - Modified from - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py # noqa - - Args: - size: the number of class - padding_idx: padding_idx: ignored class id - smoothing: smoothing rate (0.0 means the conventional CE) - normalize_length: normalize loss by sequence length if True - criterion: loss function to be smoothed - """ - - def __init__( - self, - size: int, - padding_idx: int = -1, - smoothing: float = 0.1, - normalize_length: bool = False, - criterion: nn.Module = nn.KLDivLoss(reduction="none"), - ) -> None: - """Construct an LabelSmoothingLoss object.""" - super(LabelSmoothingLoss, self).__init__() - self.criterion = criterion - self.padding_idx = padding_idx - assert 0.0 < smoothing <= 1.0 - self.confidence = 1.0 - smoothing - self.smoothing = smoothing - self.size = size - self.true_dist = None - self.normalize_length = normalize_length - - def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Compute loss between x and target. - - Args: - x: - prediction of dimension - (batch_size, input_length, number_of_classes). - target: - target masked with self.padding_id of - dimension (batch_size, input_length). - - Returns: - A scalar tensor containing the loss without normalization. 
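A quick sketch (not part of this patch) of how the Noam schedule above behaves: the
learning rate at step s is factor * model_size**-0.5 * min(s**-0.5, s * warm_step**-1.5),
i.e. linear warm-up followed by inverse-square-root decay.  The Linear module below is
just a stand-in for the real model; any nn.Module's parameters work.

import torch
import torch.nn as nn

model = nn.Linear(256, 256)  # stand-in model
optimizer = Noam(model.parameters(), model_size=256, factor=1.0, warm_step=25000)

loss = model(torch.randn(8, 256)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()  # sets the lr for the current step, then calls Adam.step()

for s in (100, 25000, 100000):
    print(s, optimizer.rate(s))  # the rate peaks around warm_step, then decays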
- """ - assert x.size(2) == self.size - # batch_size = x.size(0) - x = x.view(-1, self.size) - target = target.view(-1) - with torch.no_grad(): - true_dist = x.clone() - true_dist.fill_(self.smoothing / (self.size - 1)) - ignore = target == self.padding_idx # (B,) - total = len(target) - ignore.sum().item() - target = target.masked_fill(ignore, 0) # avoid -1 index - true_dist.scatter_(1, target.unsqueeze(1), self.confidence) - kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) - # denom = total if self.normalize_length else batch_size - denom = total if self.normalize_length else 1 - return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom - - - - -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: - """Generate a length mask for input. - - The masked position are filled with True, - Unmasked positions are filled with False. - - Args: - ys_pad: - padded tensor of dimension (batch_size, input_length). - ignore_id: - the ignored number (the padding number) in ys_pad - - Returns: - Tensor: - a bool tensor of the same shape as the input tensor. - """ - ys_mask = ys_pad == ignore_id - return ys_mask - - -def generate_square_subsequent_mask(sz: int, device: torch.Device) -> torch.Tensor: - """Generate a square mask for the sequence. The masked positions are - filled with float('-inf'). Unmasked positions are filled with float(0.0). - The mask can be used for masked self-attention. - - For instance, if sz is 3, it returns:: - - tensor([[0., -inf, -inf], - [0., 0., -inf], - [0., 0., 0]]) - - Args: - sz: mask size - - Returns: - A square mask of dimension (sz, sz) - """ - mask = (torch.triu(torch.ones(sz, sz, device=torch.Device)) == 1).transpose(0, 1) - mask = ( - mask.float() - .masked_fill(mask == 0, float("-inf")) - .masked_fill(mask == 1, float(0.0)) - ) - return mask - - -def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: - """Prepend sos_id to each utterance. - - Args: - token_ids: - A list-of-list of token IDs. Each sublist contains - token IDs (e.g., word piece IDs) of an utterance. - sos_id: - The ID of the SOS token. - - Return: - Return a new list-of-list, where each sublist starts - with SOS ID. - """ - ans = [] - for utt in token_ids: - ans.append([sos_id] + utt) - return ans - - -def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: - """Append eos_id to each utterance. - - Args: - token_ids: - A list-of-list of token IDs. Each sublist contains - token IDs (e.g., word piece IDs) of an utterance. - eos_id: - The ID of the EOS token. - - Return: - Return a new list-of-list, where each sublist ends - with EOS ID. - """ - ans = [] - for utt in token_ids: - ans.append(utt + [eos_id]) - return ans - - - -class MaskedConvolutionModule(nn.Module): - """ - This is used in the MaskedLmConformerLayer. It is the same as the ConvolutionModule - of theConformer code, but with key_padding_mask supported to make the output independent - of the batching. - - Modified, ultimately, from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py - - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). 
- """ - - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: - """Construct a MaskedConvolutionModule object.""" - super(MaskedConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0 - - self.pointwise_conv1 = nn.Conv1d( - channels, - 2 * channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - self.depthwise_conv = nn.Conv1d( - channels, - channels, - kernel_size, - stride=1, - padding=(kernel_size - 1) // 2, - groups=channels, - bias=bias, - ) - self.norm = nn.LayerNorm(channels) - self.pointwise_conv2 = nn.Conv1d( - channels, - channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - self.activation = Swish() - - def forward(self, x: Tensor, key_padding_mask: Optional[Tensor]) -> Tensor: - """Compute convolution module. - - Args: - x: Input tensor (T, N, C) == (#time, batch, channels). - key_padding_mask: if supplied, a Tensor with dtype=torch.Bool and - shape (N, T), with True for positions that correspond to - padding (and should be zeroed in convolutions). - - Returns: - Tensor: Output tensor (T, N, C) - - """ - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channels, time) - x = nn.functional.glu(x, dim=1) # (batch, channels, time) - - # Logical-not key_padding_mask, unsqueeze to shape (N, 1, T) and convert - # to float. Then we can just multiply by it when we need to apply - # masking, i.e. prior to the convolution over time. - if key_padding_mask is not None: - x = x * torch.logical_not(key_padding_mask).unsqueeze(1).to(dtype=x.dtype) - - # 1D Depthwise Conv - x = self.depthwise_conv(x) - x = self.activation(self.norm(x)) - - x = self.pointwise_conv2(x) # (batch, channel, time) - - return x.permute(2, 0, 1) # (time, batch, channel) - - -class Swish(torch.nn.Module): - """Construct an Swish object.""" - - def forward(self, x: Tensor) -> Tensor: - """Return Swich activation function.""" - return x * torch.sigmoid(x) - - - -class MaskedLmConformerEncoderLayer(nn.Module): - """ - MaskedLmConformerEncoderLayer is made up of self-attn, feedforward and convolution - networks. It's a simplified version of the conformer code we were previously - using, with pre-normalization hard-coded, relative positional encoding, - LayerNorm instead of BatchNorm in the convolution layers, and the key_padding_mask - applied also in the convolution layers. - - See: "Conformer: Convolution-augmented Transformer for Speech Recognition", for - the basic conformer. - - Args: - d_model: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - dim_feedforward: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - cnn_module_kernel (int): Kernel size of convolution module. 
- - Examples:: - >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = encoder_layer(src, pos_emb) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - ) -> None: - super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) - - self.feed_forward = nn.Sequential( - nn.Linear(d_model, dim_feedforward), - Swish(), - nn.Dropout(dropout), - nn.Linear(dim_feedforward, d_model), - ) - - self.feed_forward_macaron = nn.Sequential( - nn.Linear(d_model, dim_feedforward), - Swish(), - nn.Dropout(dropout), - nn.Linear(dim_feedforward, d_model), - ) - - self.conv_module = MaskedConvolutionModule(d_model, cnn_module_kernel) - - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module - self.norm_ff = nn.LayerNorm(d_model) # for the FNN module - self.norm_mha = nn.LayerNorm(d_model) # for the MHA module - - self.ff_scale = 0.5 - - self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block - - self.dropout = nn.Dropout(dropout) - - def forward( - self, - x: Tensor, - pos_emb: Tensor, - attn_mask: Optional[Tensor] = None, - key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Pass the input through the encoder layer. - - Args: - x: the sequence to the encoder layer (required). - pos_emb: Positional embedding tensor (required). - attn_mask: the mask for the x sequence's attention to itself (optional); - of shape (T, T) - key_padding_mask: the mask for the src keys per batch (optional). - - Shape: - x: (T, N, C) i.e. (seq_len, batch_size, num_channels) - pos_emb: (N, 2*T-1, C) - attn_mask: (T, T) or (N*num_heads, T, T), of dtype torch.bool or torch.float, where - the 1st S is interpreted as the target sequence (output) and the 2nd as the source - sequence (input). - key_padding_mask: (N, T), of dtype torch.bool - - T is the sequence length, N is the batch size, C is the number of channels. - Return: - Returns x with something added to it, of shape (T, N, C) - """ - - # macaron style feed forward module - residual = x - x = self.norm_ff_macaron(x) - x = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(x) - ) - - # multi-headed self-attention module - residual = x - x = self.norm_mha(x) - self_attn = self.self_attn(x, x, x, - pos_emb=pos_emb, - attn_mask=attn_mask, - key_padding_mask=key_padding_mask, - need_weights=False - )[0] - x = residual + self.dropout(self_attn) - - # convolution module - residual = x - x = self.norm_conv(x) - - x = residual + self.dropout(self.conv_module(x, key_padding_mask=key_padding_mask)) - - # feed forward module - residual = x - x = self.norm_ff(x) - x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) - - x = self.norm_final(x) - - return x - - -def _get_clones(module, N): - return ModuleList([copy.deepcopy(module) for i in range(N)]) - -class MaskedLmConformerEncoder(Module): - r"""MaskedLmConformerEncoder is a stack of N encoder layers, modified from - torch.nn.TransformerEncoder - - Args: - encoder_layer: an instance of the MaskedLmConformerEncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). - norm: the layer normalization component (optional). 
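Not part of this patch: the key_padding_mask trick that MaskedConvolutionModule applies
before its depthwise convolution, shown here on a bare depthwise Conv1d.  Zeroing the
padded frames makes the outputs at real frames independent of whatever padding values
happen to be in the batch.

import torch
import torch.nn as nn

conv = nn.Conv1d(4, 4, kernel_size=3, padding=1, groups=4)  # depthwise, 'SAME' padding
x = torch.randn(2, 4, 6)                                    # (batch, channels, time)
key_padding_mask = torch.tensor([[False] * 6,
                                 [False] * 4 + [True] * 2])  # True == padding

keep = torch.logical_not(key_padding_mask).unsqueeze(1).to(x.dtype)  # (N, 1, T)

x2 = x.clone()
x2[1, :, 4:] = 123.0  # change only the padded frames of the second sequence

# Without masking, the last real frame "sees" the padding values:
assert not torch.allclose(conv(x)[1, :, 3], conv(x2)[1, :, 3])
# With masking, the outputs do not depend on the padding values at all:
assert torch.allclose(conv(x * keep), conv(x2 * keep))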
- 
-     Examples::
-         >>> encoder_layer = MaskedLmConformerEncoderLayer(d_model=512, nhead=8)
-         >>> conformer_encoder = MaskedLmConformerEncoder(encoder_layer, num_layers=6)
-         >>> encoder_pos = RelPositionalEncoding(512, dropout_rate=0.1)
-         >>> src = torch.rand(32, 10, 512)      # (N, T, C)
-         >>> src, pos_emb = encoder_pos(src)
-         >>> src = src.permute(1, 0, 2)         # (T, N, C)
-         >>> out = conformer_encoder(src, pos_emb)
-     """
-     __constants__ = ['norm']
- 
-     def __init__(self, encoder_layer: nn.Module, num_layers: int,
-                  norm: Optional[nn.Module] = None):
-         super(MaskedLmConformerEncoder, self).__init__()
-         self.layers = _get_clones(encoder_layer, num_layers)
-         self.num_layers = num_layers
-         self.norm = norm
- 
-     def forward(
-         self,
-         x: Tensor,
-         pos_emb: Tensor,
-         attn_mask: Optional[Tensor] = None,
-         key_padding_mask: Optional[Tensor] = None,
-     ) -> Tensor:
-         r"""Pass the input through the encoder layers in turn.
- 
-         Args:
-           x: input of shape (T, N, C), i.e. (seq_len, batch, channels)
-           pos_emb: positional embedding tensor of shape (1, 2*T-1, C) or (N, 2*T-1, C)
-           attn_mask (optional, likely not used): mask for self-attention of
-              x to itself, of shape (T, T)
-           key_padding_mask (optional): mask of shape (N, T); dtype must be bool.
-         Returns:
-           Returns a tensor with the same shape as x, i.e. (T, N, C).
-         """
-         for mod in self.layers:
-             x = mod(
-                 x,
-                 pos_emb,
-                 attn_mask=attn_mask,
-                 key_padding_mask=key_padding_mask,
-             )
- 
-         if self.norm is not None:
-             x = self.norm(x)
- 
-         return x
- 
- 
- class RelPositionalEncoding(torch.nn.Module):
-     """Relative positional encoding module.
- 
-     See: Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
-     Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
- 
-     Args:
-         d_model: Embedding dimension.
-         dropout_rate: Dropout rate.
-         max_len: Maximum input length.
- 
-     """
- 
-     def __init__(
-         self, d_model: int, dropout_rate: float, max_len: int = 5000
-     ) -> None:
-         """Construct a RelPositionalEncoding object."""
-         super(RelPositionalEncoding, self).__init__()
-         self.d_model = d_model
-         self.dropout = torch.nn.Dropout(p=dropout_rate)
-         self.pe = None
-         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
- 
-     def extend_pe(self, x: Tensor) -> None:
-         """Reset the positional encodings."""
-         if self.pe is not None:
-             # self.pe contains both positive and negative parts
-             # the length of self.pe is 2 * input_len - 1
-             if self.pe.size(1) >= x.size(1) * 2 - 1:
-                 # Note: TorchScript doesn't implement operator== for torch.Device
-                 if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                     x.device
-                 ):
-                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
-                 return
-         # Suppose `i` is the position of the query vector and `j` is the
-         # position of the key vector.  We use positive relative positions when
-         # keys are to the left (i>j) and negative relative positions otherwise (i<j).
-         pe_positive = torch.zeros(x.size(1), self.d_model)
-         pe_negative = torch.zeros(x.size(1), self.d_model)
-         position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
-         div_term = torch.exp(
-             torch.arange(0, self.d_model, 2, dtype=torch.float32)
-             * -(math.log(10000.0) / self.d_model)
-         )
-         pe_positive[:, 0::2] = torch.sin(position * div_term)
-         pe_positive[:, 1::2] = torch.cos(position * div_term)
-         pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
-         pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
- 
-         # Reverse the order of the positive part and concatenate the positive
-         # and negative parts; this layout supports the shifting trick described
-         # in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context".
-         pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
-         pe_negative = pe_negative[1:].unsqueeze(0)
-         pe = torch.cat([pe_positive, pe_negative], dim=1)
-         self.pe = pe.to(device=x.device, dtype=x.dtype)
- 
-     def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
-         """Add positional encoding.
- 
-         Args:
-             x (torch.Tensor): Input tensor (batch, time, `*`).
- 
-         Returns:
-             torch.Tensor: Encoded tensor (batch, time, `*`).
-             torch.Tensor: Encoded tensor (1, 2*time-1, `*`).
- 
-         """
-         self.extend_pe(x)
-         pos_emb = self.pe[
-             :,
-             self.pe.size(1) // 2
-             - x.size(1)
-             + 1 : self.pe.size(1) // 2
-             + x.size(1),
-         ]
-         return self.dropout(x), self.dropout(pos_emb)
- 
- 
- class RelPositionMultiheadAttention(nn.Module):
-     r"""Multi-Head Attention layer with relative position encoding
- 
-     See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
- 
-     Args:
-         embed_dim: total dimension of the model.
-         num_heads: parallel attention heads.
- dropout: a Dropout layer on attn_output_weights. Default: 0.0. - - Examples:: - - >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) - >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) - """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - ) -> None: - super(RelPositionMultiheadAttention, self).__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - assert ( - self.head_dim * num_heads == self.embed_dim - ), "embed_dim must be divisible by num_heads" - - self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) - - # linear transformation for positional encoding. - self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) - # these two learnable bias are used in matrix c and matrix d - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 - self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - - self._reset_parameters() - - def _reset_parameters(self) -> None: - nn.init.xavier_uniform_(self.in_proj.weight) - nn.init.constant_(self.in_proj.bias, 0.0) - nn.init.constant_(self.out_proj.bias, 0.0) - - nn.init.xavier_uniform_(self.pos_bias_u) - nn.init.xavier_uniform_(self.pos_bias_v) - - def forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Optional[Tensor]]: - r""" - Args: - query, key, value: map a query and a set of key-value pairs to an output. - pos_emb: Positional embedding tensor - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. When given a binary mask and a value is True, - the corresponding value on the attention layer will be ignored. When given - a byte mask and a value is non-zero, the corresponding value on the attention - layer will be ignored - need_weights: if true, return (output, attn_output_weights); else, (output, None). - - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - - Shape: - - Inputs: - - query: :math:`(T, N, C)` where T is the output sequence length, N is the batch size, C is - the embedding dimension (number of channels). - - key: :math:`(S, N, C)`, where S is the input sequence length. - - value: :math:`(S, N, C)` - - pos_emb: :math:`(N, 2*T-1, C)`. Note: this assumes T == S, which it will be, but - still we use different letters because S relates to the input position, T to the - output posision. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the input sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the position - with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(T, S)` where T is the output sequence length, S is the input sequence length. 
- 3D mask :math:`(N*num_heads, T, S)` where N is the batch size, where T is the output sequence length, - S is the input sequence length. attn_mask ensure that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - Return: - (output, attn_output_weights) if need_weights==True, else (output, None), where: - - - output: :math:`(T, N, C)` where T is the output sequence length, N is the batch size, - C is the embedding/channel dimension. - - attn_output_weights: :math:`(N, T, S)` where N is the batch size, - T is the output sequence length, S is the input sequence length. - """ - return self.multi_head_attention_forward( - query, - key, - value, - pos_emb, - self.embed_dim, - self.num_heads, - self.in_proj.weight, - self.in_proj.bias, - self.dropout, - self.out_proj.weight, - self.out_proj.bias, - training=self.training, - key_padding_mask=key_padding_mask, - need_weights=need_weights, - attn_mask=attn_mask, - ) - - def rel_shift(self, x: Tensor) -> Tensor: - """Compute relative positional encoding. - - Args: - x: Input tensor (batch, head, time1, 2*time1-1). - time1 means the length of query vector. - - Returns: - Tensor: tensor of shape (batch, head, time1, time2) - (note: time2 has the same value as time1, but it is for - the key, while time1 is for the query). - """ - (batch_size, num_heads, time1, n) = x.shape - assert n == 2 * time1 - 1 - # Note: TorchScript requires explicit arg for stride() - batch_stride = x.stride(0) - head_stride = x.stride(1) - time1_stride = x.stride(2) - n_stride = x.stride(3) - return x.as_strided( - (batch_size, num_heads, time1, time1), - (batch_stride, head_stride, time1_stride - n_stride, n_stride), - storage_offset=n_stride * (time1 - 1), - ) - - def multi_head_attention_forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - pos_emb: Tensor, - embed_dim_to_check: int, - num_heads: int, - in_proj_weight: Tensor, - in_proj_bias: Tensor, - dropout_p: float, - out_proj_weight: Tensor, - out_proj_bias: Tensor, - training: bool = True, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Optional[Tensor]]: - r""" - Args: - query, key, value: map a query and a set of key-value pairs to an output. - pos_emb: Positional embedding tensor - embed_dim_to_check: total dimension of the model. - num_heads: parallel attention heads. - in_proj_weight, in_proj_bias: input projection weight and bias. - dropout_p: probability of an element to be zeroed. - out_proj_weight, out_proj_bias: the output projection weight and bias. - training: apply dropout if is ``True``. - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. This is an binary mask. When the value is True, - the corresponding value on the attention layer will be filled with -inf. - need_weights: output attn_output_weights. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. 
- - Shape: - Inputs: - - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence - length, N is the batch size, E is the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions - will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_output_weights: :math:`(N, L, S)` where N is the batch size, - L is the target sequence length, S is the source sequence length. 
- """ - - tgt_len, bsz, embed_dim = query.size() - assert embed_dim == embed_dim_to_check - assert key.size(0) == value.size(0) and key.size(1) == value.size(1) - - head_dim = embed_dim // num_heads - assert ( - head_dim * num_heads == embed_dim - ), "embed_dim must be divisible by num_heads" - scaling = float(head_dim) ** -0.5 - - if torch.equal(query, key) and torch.equal(key, value): - # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) - - elif torch.equal(key, value): - # encoder-decoder attention - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) - - else: - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = embed_dim * 2 - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - k = nn.functional.linear(key, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim * 2 - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - v = nn.functional.linear(value, _w, _b) - - #if not self.is_espnet_structure: - # q = q * scaling - - if attn_mask is not None: - assert ( - attn_mask.dtype == torch.float32 - or attn_mask.dtype == torch.float64 - or attn_mask.dtype == torch.float16 - or attn_mask.dtype == torch.uint8 - or attn_mask.dtype == torch.bool - ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( - attn_mask.dtype - ) - if attn_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for attn_mask is deprecated. Use bool tensor instead." - ) - attn_mask = attn_mask.to(torch.bool) - - if attn_mask.dim() == 2: - attn_mask = attn_mask.unsqueeze(0) - if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) - elif attn_mask.dim() == 3: - if list(attn_mask.size()) != [ - bsz * num_heads, - query.size(0), - key.size(0), - ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) - else: - raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) - ) - # attn_mask's dim is 3 now. - - # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): - warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
- ) - key_padding_mask = key_padding_mask.to(torch.bool) - - q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim) - k = k.contiguous().view(-1, bsz, num_heads, head_dim) - v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) - - src_len = k.size(0) - - if key_padding_mask is not None: - assert key_padding_mask.size(0) == bsz, "{} == {}".format( - key_padding_mask.size(0), bsz - ) - assert key_padding_mask.size(1) == src_len, "{} == {}".format( - key_padding_mask.size(1), src_len - ) - - q = q.transpose(0, 1) # (batch, time1, head, d_k) - - pos_emb_bsz = pos_emb.size(0) - assert pos_emb_bsz in (1, bsz) # actually it is 1 - p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) - p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) - - q_with_bias_u = (q + self.pos_bias_u).transpose( - 1, 2 - ) # (batch, head, time1, d_k) - - q_with_bias_v = (q + self.pos_bias_v).transpose( - 1, 2 - ) # (batch, head, time1, d_k) - - # compute attention score - # first compute matrix a and matrix c - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 - k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) - - # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p.transpose(-2, -1) - ) # (batch, head, time1, 2*time1-1) - matrix_bd = self.rel_shift(matrix_bd) - - #if not self.is_espnet_structure: - # attn_output_weights = ( - # matrix_ac + matrix_bd - # ) # (batch, head, time1, time2) - #else: - - attn_output_weights = ( - matrix_ac + matrix_bd - ) * scaling # (batch, head, time1, time2) - - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) - - assert list(attn_output_weights.size()) == [ - bsz * num_heads, - tgt_len, - src_len, - ] - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_output_weights.masked_fill_(attn_mask, float("-inf")) - else: - attn_output_weights += attn_mask - - if key_padding_mask is not None: - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len - ) - attn_output_weights = attn_output_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), - float("-inf"), - ) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, src_len - ) - - attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) - attn_output_weights = nn.functional.dropout( - attn_output_weights, p=dropout_p, training=training - ) - - attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] - attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias - ) - - if need_weights: - # average attention weights over heads - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len - ) - return attn_output, attn_output_weights.sum(dim=1) / num_heads - else: - return attn_output, None
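Not part of this patch: two small sanity checks of the relative-position machinery above.
The first shows what rel_shift() does on a toy tensor; the second runs
RelPositionMultiheadAttention together with RelPositionalEncoding and checks shapes.
All sizes are illustrative.

import torch

# rel_shift: row i of the output selects a sliding window of the 2*T-1 input
# columns, so that column j of the output corresponds to relative position j - i.
attn = RelPositionMultiheadAttention(embed_dim=4, num_heads=1)
toy = torch.arange(5.0).expand(1, 1, 3, 5).contiguous()  # (batch, head, T, 2*T-1)
print(attn.rel_shift(toy))
# tensor([[[[2., 3., 4.],
#           [1., 2., 3.],
#           [0., 1., 2.]]]])

T, N, C, H = 6, 2, 16, 4
attn = RelPositionMultiheadAttention(embed_dim=C, num_heads=H)
pos = RelPositionalEncoding(C, dropout_rate=0.0)

x = torch.randn(N, T, C)
x, pos_emb = pos(x)                # x: (N, T, C), pos_emb: (1, 2*T-1, C)
x = x.permute(1, 0, 2)             # (T, N, C), as the attention module expects

key_padding_mask = torch.zeros(N, T, dtype=torch.bool)
key_padding_mask[:, -2:] = True    # pretend the last two frames are padding

out, weights = attn(x, x, x, pos_emb, key_padding_mask=key_padding_mask)
assert out.shape == (T, N, C)
assert weights.shape == (N, T, T)  # head-averaged weights when need_weights=True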