# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from subsampling import Conv2dSubsampling, VggSubsampling, ScaledLinear

from icefall.utils import make_pad_mask


class Transformer(EncoderInterface):
    def __init__(
        self,
        num_features: int,
        output_dim: int,
        subsampling_factor: int = 4,
        d_model: int = 256,
        nhead: int = 4,
        dim_feedforward: int = 2048,
        num_encoder_layers: int = 12,
        dropout: float = 0.1,
        normalize_before: bool = True,
        vgg_frontend: bool = False,
    ) -> None:
        """
        Args:
          num_features:
            The input dimension of the model.
          output_dim:
            The output dimension of the model.
          subsampling_factor:
            Number of output frames is num_in_frames // subsampling_factor.
            Currently, subsampling_factor MUST be 4.
          d_model:
            Attention dimension.
          nhead:
            Number of heads in multi-head attention.
            Must satisfy d_model % nhead == 0.
          dim_feedforward:
            The hidden dimension of the feedforward layers in the encoder.
          num_encoder_layers:
            Number of encoder layers.
          dropout:
            Dropout in encoder.
          normalize_before:
            If True, use pre-layer norm; if False, use post-layer norm.
          vgg_frontend:
            True to use a VGG-style frontend for subsampling.
        """
        super().__init__()

        self.num_features = num_features
        self.output_dim = output_dim
        self.subsampling_factor = subsampling_factor
        if subsampling_factor != 4:
            raise NotImplementedError("Support only 'subsampling_factor=4'.")

        # self.encoder_embed converts the input of shape (N, T, num_features)
        # to the shape (N, T//subsampling_factor, d_model).
        # That is, it does two things simultaneously:
        #   (1) subsampling: T -> T//subsampling_factor
        #   (2) embedding: num_features -> d_model
        if vgg_frontend:
            self.encoder_embed = VggSubsampling(num_features, d_model)
        else:
            self.encoder_embed = Conv2dSubsampling(num_features, d_model)

        self.encoder_pos = PositionalEncoding(d_model, dropout)

        encoder_layer = TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            normalize_before=normalize_before,
        )

        if normalize_before:
            encoder_norm = nn.LayerNorm(d_model)
        else:
            encoder_norm = None

        self.encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=num_encoder_layers,
            norm=encoder_norm,
        )

        # TODO(fangjun): remove dropout
        self.encoder_output_layer = nn.Sequential(
            nn.Dropout(p=dropout), ScaledLinear(d_model, output_dim)
        )

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            The input tensor. Its shape is (batch_size, seq_len, feature_dim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames
            in `x` before padding.
        Returns:
          Return a tuple containing 2 tensors:
            - logits, its shape is (batch_size, output_seq_len, output_dim)
            - logit_lens, a tensor of shape (batch_size,) containing the
              number of frames in `logits` before padding.
        """
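        # Editorial shape note (matches the code below, assuming
        # subsampling_factor == 4, i.e. two stride-2 reductions):
        #   x:           (N, T, num_features)
        #   embedded:    (N, T', d_model) with T' = ((T - 1) // 2 - 1) // 2
        #   encoder in:  (T', N, d_model) after the permute
        #   logits:      (N, T', output_dim) after the final permute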
""" x = self.encoder_embed(x) x = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) # Caution: We assume the subsampling factor is 4! lengths = ((x_lens - 1) // 2 - 1) // 2 assert x.size(0) == lengths.max().item() mask = make_pad_mask(lengths) x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C) logits = self.encoder_output_layer(x) logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) return logits, lengths class TransformerEncoderLayer(nn.Module): """ Modified from torch.nn.TransformerEncoderLayer. Add support of normalize_before, i.e., use layer_norm before the first block. Args: d_model: the number of expected features in the input (required). nhead: the number of heads in the multiheadattention models (required). dim_feedforward: the dimension of the feedforward network model (default=2048). dropout: the dropout value (default=0.1). activation: the activation function of intermediate layer, relu or gelu (default=relu). normalize_before: whether to use layer_norm before the first block. Examples:: >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) >>> src = torch.rand(10, 32, 512) >>> out = encoder_layer(src) """ def __init__( self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, activation: str = "relu", normalize_before: bool = True, ) -> None: super(TransformerEncoderLayer, self).__init__() self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) # Implementation of Feedforward model self.linear1 = nn.Linear(d_model, dim_feedforward) self.dropout = nn.Dropout(dropout) self.linear2 = nn.Linear(dim_feedforward, d_model) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) self.activation = _get_activation_fn(activation) self.normalize_before = normalize_before def __setstate__(self, state): if "activation" not in state: state["activation"] = nn.functional.relu super(TransformerEncoderLayer, self).__setstate__(state) def forward( self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None, src_key_padding_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Pass the input through the encoder layer. Args: src: the sequence to the encoder layer (required). src_mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional) Shape: src: (S, N, E). src_mask: (S, S). src_key_padding_mask: (N, S). 
        residual = src
        if self.normalize_before:
            src = self.norm1(src)
        src2 = self.self_attn(
            src,
            src,
            src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )[0]
        src = residual + self.dropout1(src2)
        if not self.normalize_before:
            src = self.norm1(src)

        residual = src
        if self.normalize_before:
            src = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = residual + self.dropout2(src2)
        if not self.normalize_before:
            src = self.norm2(src)
        return src


def _get_activation_fn(activation: str):
    if activation == "relu":
        return nn.functional.relu
    elif activation == "gelu":
        return nn.functional.gelu

    raise RuntimeError(
        "activation should be relu/gelu, not {}".format(activation)
    )


class PositionalEncoding(nn.Module):
    """This class implements the positional encoding proposed in the
    following paper:

    - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf

        PE(pos, 2i) = sin(pos / (10000^(2i/d_model)))
        PE(pos, 2i+1) = cos(pos / (10000^(2i/d_model)))

    Note::

      1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model)))
                               = exp(-1 * 2i / d_model * log(10000))
                               = exp(2i * -(log(10000) / d_model))
    """

    def __init__(self, d_model: int, dropout: float = 0.1) -> None:
        """
        Args:
          d_model:
            Embedding dimension.
          dropout:
            Dropout probability to be applied to the output of this module.
        """
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = nn.Dropout(p=dropout)
        # not doing: self.pe = None because of errors thrown by torchscript
        self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32)

    def extend_pe(self, x: torch.Tensor) -> None:
        """Extend the time t in the positional encoding if required.

        The shape of `self.pe` is (1, T1, d_model). The shape of the input x
        is (N, T, d_model). If T > T1, then we change the shape of self.pe
        to (1, T, d_model). Otherwise, nothing is done.

        Args:
          x:
            It is a tensor of shape (N, T, C).
        Returns:
          Return None.
        """
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)
        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        # Now pe is of shape (1, T, d_model), where T is x.size(1)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add positional encoding.

        Args:
          x:
            Its shape is (N, T, C).

        Returns:
          Return a tensor of shape (N, T, C).
        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1), :]
        return self.dropout(x)
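
# Editorial usage sketch for PositionalEncoding (illustrative only; the
# shapes below are assumptions consistent with how Transformer.forward
# calls it):
#
#   >>> pos = PositionalEncoding(d_model=256, dropout=0.1)
#   >>> x = torch.rand(8, 100, 256)   # (N, T, C)
#   >>> y = pos(x)                    # (N, T, C): dropout(x * sqrt(d_model) + PE[:, :T])
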
Proposed in "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa Args: params: iterable of parameters to optimize or dicts defining parameter groups model_size: attention dimension of the transformer model factor: learning rate factor warm_step: warmup steps """ def __init__( self, params, model_size: int = 256, factor: float = 10.0, warm_step: int = 25000, weight_decay=0, ) -> None: """Construct an Noam object.""" self.optimizer = torch.optim.Adam( params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay ) self._step = 0 self.warmup = warm_step self.factor = factor self.model_size = model_size self._rate = 0 @property def param_groups(self): """Return param_groups.""" return self.optimizer.param_groups def step(self): """Update parameters and rate.""" self._step += 1 rate = self.rate() for p in self.optimizer.param_groups: p["lr"] = rate self._rate = rate self.optimizer.step() def rate(self, step=None): """Implement `lrate` above.""" if step is None: step = self._step return ( self.factor * self.model_size ** (-0.5) * self.warmup ** (-0.5 - -0.333) * min(step ** (-0.333), step * self.warmup ** (-1.333)) ) def zero_grad(self): """Reset gradient.""" self.optimizer.zero_grad() def state_dict(self): """Return state_dict.""" return { "_step": self._step, "warmup": self.warmup, "factor": self.factor, "model_size": self.model_size, "_rate": self._rate, "optimizer": self.optimizer.state_dict(), } def load_state_dict(self, state_dict): """Load state_dict.""" for key, value in state_dict.items(): if key == "optimizer": self.optimizer.load_state_dict(state_dict["optimizer"]) else: setattr(self, key, value)