#!/usr/bin/env python3
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import math
import warnings
from typing import Optional, Sequence, Tuple

import torch
from torch import Tensor, nn

from encoder_interface import EncoderInterface
from icefall.utils import make_pad_mask
from scaling import (
    ActivationBalancer,
    BasicNorm,
    DoubleSwish,
    ScaledConv1d,
    ScaledConv2d,
    ScaledLinear,
)


class Conformer(EncoderInterface):
    """Conformer encoder: a Conv2d subsampling front-end followed by a stack
    of Conformer layers using relative-position self-attention.

    Args:
        num_features (int): Number of input features
        output_dim (int): Number of output dimension
        subsampling_factor (int): subsampling factor of encoder (the
            convolution layers before transformers). Only 4 is supported.
        d_model (int): attention dimension; also the output dimension of the
            subsampling front-end
        nhead (int): number of heads
        dim_feedforward (int): feedforward dimension
        num_encoder_layers (int): number of encoder layers
        dropout (float): dropout rate
        cnn_module_kernel (int): Kernel size of convolution module
        aux_layer_period (int): step used to build the auxiliary-layer indices
            passed to ConformerEncoder, i.e.
            ``range(0, num_encoder_layers - 1, aux_layer_period)``.
    """

    def __init__(
        self,
        num_features: int,
        output_dim: int,
        subsampling_factor: int = 4,
        d_model: int = 256,
        nhead: int = 4,
        dim_feedforward: int = 2048,
        num_encoder_layers: int = 12,
        dropout: float = 0.1,
        cnn_module_kernel: int = 31,
        aux_layer_period: int = 3,
    ) -> None:
        super(Conformer, self).__init__()

        self.num_features = num_features
        self.output_dim = output_dim
        self.subsampling_factor = subsampling_factor
        if subsampling_factor != 4:
            raise NotImplementedError("Support only 'subsampling_factor=4'.")

        # self.encoder_embed converts the input of shape (N, T, num_features)
        # to the shape (N, T//subsampling_factor, d_model).
        # That is, it does two things simultaneously:
        #   (1) subsampling: T -> T//subsampling_factor
        #   (2) embedding: num_features -> d_model
        # The second positional argument of Conv2dSubsampling is its OUTPUT
        # dimension, so it must be d_model for the encoder layers to accept it.
        self.encoder_embed = Conv2dSubsampling(num_features, d_model)

        self.encoder_pos = RelPositionalEncoding(d_model, dropout)

        encoder_layer = ConformerEncoderLayer(
            d_model,
            nhead,
            dim_feedforward,
            dropout,
            cnn_module_kernel,
        )
        self.encoder = ConformerEncoder(
            encoder_layer,
            num_encoder_layers,
            aux_layers=list(
                range(0, num_encoder_layers - 1, aux_layer_period)
            ),
        )

        self.encoder_output_layer = nn.Sequential(
            nn.Dropout(p=dropout), ScaledLinear(d_model, output_dim)
        )

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            The input tensor. Its shape is (batch_size, seq_len, feature_dim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `x` before padding.
          warmup:
            A floating point value that gradually increases from 0 throughout
            training; when it is >= 1.0 we are "fully warmed up". It is passed
            to every encoder layer, where it controls how often the layer's
            output is used at full scale.
        Returns:
          Return a tuple containing 2 tensors:
            - logits, its shape is (batch_size, output_seq_len, output_dim)
            - logit_lens, a tensor of shape (batch_size,) containing the number
              of frames in `logits` before padding.
        """
        x = self.encoder_embed(x)
        x, pos_emb = self.encoder_pos(x)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        # Caution: We assume the subsampling factor is 4!
        lengths = ((x_lens - 1) // 2 - 1) // 2
        assert x.size(0) == lengths.max().item()
        mask = make_pad_mask(lengths)

        x = self.encoder(
            x, pos_emb, src_key_padding_mask=mask, warmup=warmup
        )  # (T, N, C)

        logits = self.encoder_output_layer(x)
        logits = logits.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)

        return logits, lengths


class ConformerEncoderLayer(nn.Module):
    """
    ConformerEncoderLayer is made up of self-attn, feedforward and convolution
    networks. See: "Conformer: Convolution-augmented Transformer for Speech
    Recognition"

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model
            (default=2048).
        dropout: the dropout value (default=0.1).
        cnn_module_kernel (int): Kernel size of convolution module.

    Examples::
        >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> pos_emb = torch.rand(32, 19, 512)
        >>> out = encoder_layer(src, pos_emb)
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        cnn_module_kernel: int = 31,
    ) -> None:
        super(ConformerEncoderLayer, self).__init__()

        self.d_model = d_model

        self.self_attn = RelPositionMultiheadAttention(
            d_model, nhead, dropout=0.0
        )

        self.feed_forward = nn.Sequential(
            ScaledLinear(d_model, dim_feedforward),
            ActivationBalancer(channel_dim=-1),
            DoubleSwish(),
            nn.Dropout(dropout),
            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
        )

        self.feed_forward_macaron = nn.Sequential(
            ScaledLinear(d_model, dim_feedforward),
            ActivationBalancer(channel_dim=-1),
            DoubleSwish(),
            nn.Dropout(dropout),
            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
        )

        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)

        self.norm_final = BasicNorm(d_model)

        # try to ensure the output is close to zero-mean (or at least,
        # zero-median).
        self.balancer = ActivationBalancer(
            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
        )

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        src: Tensor,
        pos_emb: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        warmup: float = 1.0,
    ) -> Tensor:
        """
        Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            pos_emb: Positional embedding tensor (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch
                (optional).
            warmup: probability (after clamping to at least 0.1, and at most
                0.95 while training) that this layer's output is used at full
                scale; otherwise the output is blended with the layer input
                using alpha=0.1. All sub-modules are always evaluated.

        Shape:
            src: (S, N, E).
            pos_emb: (N, 2*S-1, E)
            src_mask: (S, S).
            src_key_padding_mask: (N, S).
            S is the source sequence length, N is the batch size, E is the
            feature number
        """
        src_orig = src

        # alpha is 1.0 with probability `warmup` and 0.1 otherwise; an alpha
        # of 0.1 (rather than 0.0) still gives a gradient, so the layer can
        # learn something while it is "turned off".
        #
        # max(warmup, 0.1) is used in place of warmup to ensure that even at
        # the start of the warm-up period we sometimes use scale 1.0; this
        # ensures that the modules do not compensate for the small scale by
        # just producing larger output.
        warmup = max(warmup, 0.1)
        if self.training:
            # effectively, layer-drop with 1-in-20 prob.
            warmup = min(warmup, 0.95)
        # NOTE(review): outside training this is still random whenever
        # warmup < 1.0; callers presumably pass the default 1.0 at inference.
        alpha = 1.0 if torch.rand(()).item() <= warmup else 0.1

        # macaron style feed forward module
        src = src + self.dropout(self.feed_forward_macaron(src))

        # multi-headed self-attention module
        src_att = self.self_attn(
            src,
            src,
            src,
            pos_emb=pos_emb,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )[0]
        src = src + self.dropout(src_att)

        # convolution module
        src = src + self.dropout(self.conv_module(src))

        # feed forward module
        src = src + self.dropout(self.feed_forward(src))

        src = self.norm_final(self.balancer(src))

        if alpha != 1.0:
            src = alpha * src + (1 - alpha) * src_orig

        return src


class ConformerEncoder(nn.Module):
    r"""ConformerEncoder is a stack of N encoder layers

    Args:
        encoder_layer: an instance of the ConformerEncoderLayer() class
            (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        aux_layers: indices of auxiliary layers; must not contain the last
            layer index, which is always added to ``self.aux_layers``.

    Examples::
        >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
        >>> conformer_encoder = ConformerEncoder(
        ...     encoder_layer, num_layers=6, aux_layers=[0, 3])
        >>> src = torch.rand(10, 32, 512)
        >>> pos_emb = torch.rand(32, 19, 512)
        >>> out = conformer_encoder(src, pos_emb)
    """

    def __init__(
        self,
        encoder_layer: nn.Module,
        num_layers: int,
        aux_layers: Sequence[int],
    ) -> None:
        super().__init__()
        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
        )
        assert num_layers - 1 not in aux_layers
        # Set union (instead of `aux_layers + [...]`) accepts any Sequence,
        # e.g. a tuple or range, as the type hint promises.
        self.aux_layers = set(aux_layers) | {num_layers - 1}
        self.num_layers = num_layers

    def forward(
        self,
        src: Tensor,
        pos_emb: Tensor,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        warmup: float = 1.0,
    ) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required).
            pos_emb: Positional embedding tensor (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch
                (optional).
            warmup: forwarded unchanged to every layer.

        Shape:
            src: (S, N, E).
            pos_emb: (N, 2*S-1, E)
            mask: (S, S).
            src_key_padding_mask: (N, S).
            S is the source sequence length, N is the batch size, E is the
            feature number
        """
        output = src
        for mod in self.layers:
            output = mod(
                output,
                pos_emb,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                warmup=warmup,
            )
        return output


class RelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding module.

    See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a
    Fixed-Length Context"
    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py

    Args:
        d_model: Embedding dimension.
        dropout_rate: Dropout rate.
        max_len: Maximum input length.
    """

    def __init__(
        self, d_model: int, dropout_rate: float, max_len: int = 5000
    ) -> None:
        """Construct an PositionalEncoding object."""
        super(RelPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x: Tensor) -> None:
        """Reset the positional encodings.

        Rebuilds ``self.pe`` (shape (1, 2*x.size(1)-1, d_model)) when the
        cached table is too short; otherwise only moves it to x's dtype/device.
        """
        if self.pe is not None:
            # self.pe contains both positive and negative parts
            # the length of self.pe is 2 * input_len - 1
            if self.pe.size(1) >= x.size(1) * 2 - 1:
                # Note: TorchScript doesn't implement operator== for
                # torch.Device
                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
                    x.device
                ):
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        # Suppose `i` means to the position of query vector and `j` means the
        # position of key vector. We use positive relative positions when keys
        # are to the left (i>j) and negative relative positions otherwise
        # (i<j).
        pe_positive = torch.zeros(x.size(1), self.d_model)
        pe_negative = torch.zeros(x.size(1), self.d_model)
        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

        # Reverse the order of positive indices and concat both positive and
        # negative indices. This is used to support the shifting trick
        # as in "Transformer-XL: Attentive Language Models Beyond a
        # Fixed-Length Context". Index pe.size(1)//2 is relative position 0.
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = torch.cat([pe_positive, pe_negative], dim=1)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
        """Return the (dropped-out) input and its relative positional embedding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: ``x`` after dropout (batch, time, `*`); no encoding
                is added to it.
            torch.Tensor: Positional embedding (1, 2*time-1, `*`), centred on
                relative position 0.
        """
        self.extend_pe(x)
        center = self.pe.size(1) // 2
        pos_emb = self.pe[:, center - x.size(1) + 1 : center + x.size(1)]  # noqa E203
        return self.dropout(x), self.dropout(pos_emb)


class RelPositionMultiheadAttention(nn.Module):
    r"""Multi-Head Attention layer with relative position encoding

    See reference: "Transformer-XL: Attentive Language Models Beyond a
    Fixed-Length Context"

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.

    Examples::
        >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = rel_pos_multihead_attn(query, key, value, pos_emb)
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
    ) -> None:
        super(RelPositionMultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
        self.out_proj = ScaledLinear(
            embed_dim, embed_dim, bias=True, initial_scale=0.25
        )

        # linear transformation for positional encoding.
        self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
        # these two learnable bias are used in matrix c and matrix d
        # as described in "Transformer-XL: Attentive Language Models Beyond a
        # Fixed-Length Context" Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
        self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
        # log-scales applied to the biases via exp() (see _pos_bias_u/_v).
        self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
        self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
        self._reset_parameters()

    def _pos_bias_u(self):
        return self.pos_bias_u * self.pos_bias_u_scale.exp()

    def _pos_bias_v(self):
        return self.pos_bias_v * self.pos_bias_v_scale.exp()

    def _reset_parameters(self) -> None:
        nn.init.normal_(self.pos_bias_u, std=0.01)
        nn.init.normal_(self.pos_bias_v, std=0.01)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        pos_emb: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Args:
            query, key, value: map a query and a set of key-value pairs to an
                output.
            pos_emb: Positional embedding tensor
            key_padding_mask: if provided, specified padding elements in the
                key will be ignored by the attention. When given a binary mask
                and a value is True, the corresponding value on the attention
                layer will be ignored. When given a byte mask and a value is
                non-zero, the corresponding value on the attention layer will
                be ignored
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain
                positions. A 2D mask will be broadcasted for all the batches
                while a 3D mask allows to specify a different mask for the
                entries of each batch.

        Shape:
            - Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length,
              N is the batch size, E is the embedding dimension.
            - key: :math:`(S, N, E)`, where S is the source sequence length,
              N is the batch size, E is the embedding dimension.
            - value: :math:`(S, N, E)` where S is the source sequence length,
              N is the batch size, E is the embedding dimension.
            - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence
              length, N is the batch size, E is the embedding dimension.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is
              the source sequence length. If a ByteTensor is provided, the
              non-zero positions will be ignored while the position with the
              zero positions will be unchanged. If a BoolTensor is provided,
              the positions with the value of ``True`` will be ignored while
              the position with the value of ``False`` will be unchanged.
            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence
              length, S is the source sequence length. 3D mask
              :math:`(N*num_heads, L, S)` where N is the batch size, L is the
              target sequence length, S is the source sequence length.
              attn_mask ensure that position i is allowed to attend the
              unmasked positions. If a ByteTensor is provided, the non-zero
              positions are not allowed to attend while the zero positions
              will be unchanged. If a BoolTensor is provided, positions with
              ``True`` is not allowed to attend while ``False`` values will be
              unchanged. If a FloatTensor is provided, it will be added to the
              attention weight.

            - Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence
              length, N is the batch size, E is the embedding dimension.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch
              size, L is the target sequence length, S is the source sequence
              length.
        """
        return self.multi_head_attention_forward(
            query,
            key,
            value,
            pos_emb,
            self.embed_dim,
            self.num_heads,
            self.in_proj.get_weight(),
            self.in_proj.get_bias(),
            self.dropout,
            self.out_proj.get_weight(),
            self.out_proj.get_bias(),
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
        )

    def rel_shift(self, x: Tensor) -> Tensor:
        """Compute relative positional encoding.

        Args:
            x: Input tensor (batch, head, time1, 2*time1-1).
                time1 means the length of query vector.

        Returns:
            Tensor: tensor of shape (batch, head, time1, time2)
          (note: time2 has the same value as time1, but it is for
          the key, while time1 is for the query).
        """
        (batch_size, num_heads, time1, n) = x.shape
        assert n == 2 * time1 - 1
        # Note: TorchScript requires explicit arg for stride()
        batch_stride = x.stride(0)
        head_stride = x.stride(1)
        time1_stride = x.stride(2)
        n_stride = x.stride(3)
        # Each query row i starts one column further left than row i-1, which
        # realigns the relative-position axis to absolute key positions.
        return x.as_strided(
            (batch_size, num_heads, time1, time1),
            (batch_stride, head_stride, time1_stride - n_stride, n_stride),
            storage_offset=n_stride * (time1 - 1),
        )

    def multi_head_attention_forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        pos_emb: Tensor,
        embed_dim_to_check: int,
        num_heads: int,
        in_proj_weight: Tensor,
        in_proj_bias: Tensor,
        dropout_p: float,
        out_proj_weight: Tensor,
        out_proj_bias: Tensor,
        training: bool = True,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Args:
            query, key, value: map a query and a set of key-value pairs to an
                output.
            pos_emb: Positional embedding tensor
            embed_dim_to_check: total dimension of the model.
            num_heads: parallel attention heads.
            in_proj_weight, in_proj_bias: input projection weight and bias.
            dropout_p: probability of an element to be zeroed.
            out_proj_weight, out_proj_bias: the output projection weight and
                bias.
            training: apply dropout if is ``True``.
            key_padding_mask: if provided, specified padding elements in the
                key will be ignored by the attention. This is an binary mask.
                When the value is True, the corresponding value on the
                attention layer will be filled with -inf.
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain
                positions. A 2D mask will be broadcasted for all the batches
                while a 3D mask allows to specify a different mask for the
                entries of each batch.

        Shape:
            Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length,
              N is the batch size, E is the embedding dimension.
            - key: :math:`(S, N, E)`, where S is the source sequence length,
              N is the batch size, E is the embedding dimension.
            - value: :math:`(S, N, E)` where S is the source sequence length,
              N is the batch size, E is the embedding dimension.
            - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L
              is the target sequence length, N is the batch size, E is the
              embedding dimension.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is
              the source sequence length. If a ByteTensor is provided, the
              non-zero positions will be ignored while the zero positions will
              be unchanged. If a BoolTensor is provided, the positions with the
              value of ``True`` will be ignored while the position with the
              value of ``False`` will be unchanged.
            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence
              length, S is the source sequence length. 3D mask
              :math:`(N*num_heads, L, S)` where N is the batch size, L is the
              target sequence length, S is the source sequence length.
              attn_mask ensures that position i is allowed to attend the
              unmasked positions. If a ByteTensor is provided, the non-zero
              positions are not allowed to attend while the zero positions
              will be unchanged. If a BoolTensor is provided, positions with
              ``True`` are not allowed to attend while ``False`` values will
              be unchanged. If a FloatTensor is provided, it will be added to
              the attention weight.

            Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence
              length, N is the batch size, E is the embedding dimension.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch
              size, L is the target sequence length, S is the source sequence
              length.
        """

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == embed_dim_to_check
        assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

        head_dim = embed_dim // num_heads
        assert (
            head_dim * num_heads == embed_dim
        ), "embed_dim must be divisible by num_heads"

        scaling = float(head_dim) ** -0.5

        if torch.equal(query, key) and torch.equal(key, value):
            # self-attention
            q, k, v = nn.functional.linear(
                query, in_proj_weight, in_proj_bias
            ).chunk(3, dim=-1)

        elif torch.equal(key, value):
            # encoder-decoder attention
            # This is inline in_proj function with in_proj_weight and
            # in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = nn.functional.linear(query, _w, _b)

            # This is inline in_proj function with in_proj_weight and
            # in_proj_bias
            _b = in_proj_bias
            _start = embed_dim
            _w = in_proj_weight[_start:, :]
            if _b is not None:
                _b = _b[_start:]
            k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)

        else:
            # This is inline in_proj function with in_proj_weight and
            # in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = nn.functional.linear(query, _w, _b)

            # This is inline in_proj function with in_proj_weight and
            # in_proj_bias
            _b = in_proj_bias
            _start = embed_dim
            _end = embed_dim * 2
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            k = nn.functional.linear(key, _w, _b)

            # This is inline in_proj function with in_proj_weight and
            # in_proj_bias
            _b = in_proj_bias
            _start = embed_dim * 2
            _w = in_proj_weight[_start:, :]
            if _b is not None:
                _b = _b[_start:]
            v = nn.functional.linear(value, _w, _b)

        if attn_mask is not None:
            assert (
                attn_mask.dtype == torch.float32
                or attn_mask.dtype == torch.float64
                or attn_mask.dtype == torch.float16
                or attn_mask.dtype == torch.uint8
                or attn_mask.dtype == torch.bool
            ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
                attn_mask.dtype
            )
            if attn_mask.dtype == torch.uint8:
                warnings.warn(
                    "Byte tensor for attn_mask is deprecated. Use bool tensor instead."
                )
                attn_mask = attn_mask.to(torch.bool)

            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(0)
                if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                    raise RuntimeError(
                        "The size of the 2D attn_mask is not correct."
                    )
            elif attn_mask.dim() == 3:
                if list(attn_mask.size()) != [
                    bsz * num_heads,
                    query.size(0),
                    key.size(0),
                ]:
                    raise RuntimeError(
                        "The size of the 3D attn_mask is not correct."
                    )
            else:
                raise RuntimeError(
                    "attn_mask's dimension {} is not supported".format(
                        attn_mask.dim()
                    )
                )
            # attn_mask's dim is 3 now.

        # convert ByteTensor key_padding_mask to bool
        if (
            key_padding_mask is not None
            and key_padding_mask.dtype == torch.uint8
        ):
            warnings.warn(
                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
            )
            key_padding_mask = key_padding_mask.to(torch.bool)

        q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim)
        k = k.contiguous().view(-1, bsz, num_heads, head_dim)
        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

        src_len = k.size(0)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz, "{} == {}".format(
                key_padding_mask.size(0), bsz
            )
            assert key_padding_mask.size(1) == src_len, "{} == {}".format(
                key_padding_mask.size(1), src_len
            )

        q = q.transpose(0, 1)  # (batch, time1, head, d_k)

        pos_emb_bsz = pos_emb.size(0)
        assert pos_emb_bsz in (1, bsz)  # actually it is 1
        p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)

        q_with_bias_u = (q + self._pos_bias_u()).transpose(
            1, 2
        )  # (batch, head, time1, d_k)

        q_with_bias_v = (q + self._pos_bias_v()).transpose(
            1, 2
        )  # (batch, head, time1, d_k)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in "Transformer-XL: Attentive Language Models Beyond a
        # Fixed-Length Context" Section 3.3
        k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
        matrix_ac = torch.matmul(
            q_with_bias_u, k
        )  # (batch, head, time1, time2)

        # compute matrix b and matrix d
        matrix_bd = torch.matmul(
            q_with_bias_v, p.transpose(-2, -1)
        )  # (batch, head, time1, 2*time1-1)
        matrix_bd = self.rel_shift(matrix_bd)

        attn_output_weights = (
            matrix_ac + matrix_bd
        )  # (batch, head, time1, time2)

        attn_output_weights = attn_output_weights.view(
            bsz * num_heads, tgt_len, -1
        )

        assert list(attn_output_weights.size()) == [
            bsz * num_heads,
            tgt_len,
            src_len,
        ]

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_output_weights.masked_fill_(attn_mask, float("-inf"))
            else:
                attn_output_weights += attn_mask

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.view(
                bsz, num_heads, tgt_len, src_len
            )
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float("-inf"),
            )
            attn_output_weights = attn_output_weights.view(
                bsz * num_heads, tgt_len, src_len
            )

        attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
        attn_output_weights = nn.functional.dropout(
            attn_output_weights, p=dropout_p, training=training
        )

        attn_output = torch.bmm(attn_output_weights, v)
        assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
        attn_output = (
            attn_output.transpose(0, 1)
            .contiguous()
            .view(tgt_len, bsz, embed_dim)
        )
        attn_output = nn.functional.linear(
            attn_output, out_proj_weight, out_proj_bias
        )

        if need_weights:
            # average attention weights over heads
            attn_output_weights = attn_output_weights.view(
                bsz, num_heads, tgt_len, src_len
            )
            return attn_output, attn_output_weights.sum(dim=1) / num_heads
        else:
            return attn_output, None


class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model.
    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py

    Args:
        channels (int): The number of channels of conv layers.
        kernel_size (int): Kernel size of conv layers.
        bias (bool): Whether to use bias in conv layers (default=True).
    """

    def __init__(
        self, channels: int, kernel_size: int, bias: bool = True
    ) -> None:
        """Construct an ConvolutionModule object."""
        super(ConvolutionModule, self).__init__()
        # kernel_size should be an odd number for 'SAME' padding
        assert (kernel_size - 1) % 2 == 0

        self.pointwise_conv1 = ScaledConv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )

        # after pointwise_conv1 we put x through a gated linear unit
        # (nn.functional.glu).
        # For most layers the normal rms value of channels of x seems to be in
        # the range 1 to 4, but sometimes, for some reason, for layer 0 the rms
        # ends up being very large, between 50 and 100 for different channels.
        # This will cause very peaky and sparse derivatives for the sigmoid
        # gating function, which will tend to make the loss function not learn
        # effectively. (for most layers the average absolute values are in the
        # range 0.5..9.0, and the average p(x>0), i.e. positive proportion, at
        # the output of pointwise_conv1.output is around 0.35 to 0.45 for
        # different layers, which likely breaks down as 0.5 for the "linear"
        # half and 0.2 to 0.3 for the part that goes into the sigmoid. The
        # idea is that if we constrain the rms values to a reasonable range via
        # a constraint of max_abs=10.0, it will be in a better position to
        # start learning something, i.e. to latch onto the correct range.
        self.deriv_balancer1 = ActivationBalancer(
            channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
        )

        self.depthwise_conv = ScaledConv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            groups=channels,
            bias=bias,
        )

        self.deriv_balancer2 = ActivationBalancer(
            channel_dim=1, min_positive=0.05, max_positive=1.0
        )

        self.activation = DoubleSwish()

        self.pointwise_conv2 = ScaledConv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
            initial_scale=0.25,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Compute convolution module.

        Args:
            x: Input tensor (#time, batch, channels).

        Returns:
            Tensor: Output tensor (#time, batch, channels).

        """
        # exchange the temporal dimension and the feature dimension
        x = x.permute(1, 2, 0)  # (#batch, channels, time).

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)

        x = self.deriv_balancer1(x)
        x = nn.functional.glu(x, dim=1)  # (batch, channels, time)

        # 1D Depthwise Conv
        x = self.depthwise_conv(x)

        x = self.deriv_balancer2(x)
        x = self.activation(x)

        x = self.pointwise_conv2(x)  # (batch, channel, time)

        return x.permute(2, 0, 1)


class Identity(torch.nn.Module):
    """Module that returns its input unchanged."""

    def forward(self, x: Tensor) -> Tensor:
        return x


class Conv2dSubsampling(nn.Module):
    """Convolutional 2D subsampling (to 1/4 length).

    Convert an input of shape (N, T, idim) to an output
    with shape (N, T', odim), where
    T' = ((T-1)//2 - 1)//2, which approximates T' == T//4

    It is based on
    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        layer1_channels: int = 64,
        layer2_channels: int = 128,
    ) -> None:
        """
        Args:
          in_channels:
            Number of channels in. The input shape is (N, T, in_channels).
            Caution: It requires: T >= 7, in_channels >= 7
          out_channels:
            Output dim. The output shape is
            (N, ((T-1)//2 - 1)//2, out_channels)
          layer1_channels:
            Number of channels in layer1
          layer2_channels:
            Number of channels in layer2
        """
        assert in_channels >= 7
        super().__init__()

        self.conv = nn.Sequential(
            ScaledConv2d(
                in_channels=1,
                out_channels=layer1_channels,
                kernel_size=3,
                stride=2,
            ),
            ActivationBalancer(channel_dim=1),
            DoubleSwish(),
            ScaledConv2d(
                in_channels=layer1_channels,
                out_channels=layer2_channels,
                kernel_size=3,
                stride=2,
            ),
            ActivationBalancer(channel_dim=1),
            DoubleSwish(),
        )
        self.out = ScaledLinear(
            layer2_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
        )
        # set learn_eps=False because out_norm is preceded by `out`, and `out`
        # itself has learned scale, so the extra degree of freedom is not
        # needed.
        self.out_norm = BasicNorm(out_channels, learn_eps=False)
        # constrain median of output to be close to zero.
        self.out_balancer = ActivationBalancer(
            channel_dim=-1, min_positive=0.45, max_positive=0.55
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x.

        Args:
          x:
            Its shape is (N, T, idim).

        Returns:
          Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
        """
        # On entry, x is (N, T, idim)
        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
        x = self.conv(x)
        # Now x is of shape
        # (N, layer2_channels, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
        x = self.out_norm(x)
        x = self.out_balancer(x)
        return x


if __name__ == "__main__":
    feature_dim = 50
    c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4)
    batch_size = 5
    seq_len = 20
    # Just make sure the forward pass runs.
    f = c(
        torch.randn(batch_size, seq_len, feature_dim),
        torch.full((batch_size,), seq_len, dtype=torch.int64),
        warmup=0.5,
    )