diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/conformer.py deleted file mode 120000 index 70a7ddf11..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/conformer.py +++ /dev/null @@ -1 +0,0 @@ -../transducer_stateless/conformer.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/conformer.py new file mode 100644 index 000000000..3e88fe4ef --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/conformer.py @@ -0,0 +1,913 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +import warnings +from typing import Optional, Tuple + +import torch +from torch import Tensor, nn +from transformer import Transformer + +from icefall.utils import make_pad_mask + + +class Conformer(Transformer): + """ + Args: + num_features (int): Number of input features + subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) + d_model (int): attention dimension + nhead (int): number of head + dim_feedforward (int): feedforward dimention + num_encoder_layers (int): number of encoder layers + dropout (float): dropout rate + cnn_module_kernel (int): Kernel size of convolution module + normalize_before (bool): whether to use layer_norm before the first block. + vgg_frontend (bool): whether to use vgg frontend. + """ + + def __init__( + self, + num_features: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + normalize_before: bool = True, + vgg_frontend: bool = False, + ) -> None: + super(Conformer, self).__init__( + num_features=num_features, + subsampling_factor=subsampling_factor, + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + num_encoder_layers=num_encoder_layers, + dropout=dropout, + normalize_before=normalize_before, + vgg_frontend=vgg_frontend, + ) + + self.encoder_pos = RelPositionalEncoding(d_model, dropout) + + encoder_layer = ConformerEncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + cnn_module_kernel, + normalize_before, + ) + self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) + self.normalize_before = normalize_before + if self.normalize_before: + self.after_norm = nn.LayerNorm(d_model) + else: + # Note: TorchScript detects that self.after_norm could be used inside forward() + # and throws an error without this change. + self.after_norm = identity + + def forward( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + Returns: + Return a tuple containing 2 tensors: + - logits, its shape is (batch_size, output_seq_len, d_model) + - logit_lens, a tensor of shape (batch_size,) containing the number + of frames in `logits` before padding. + """ + x = self.encoder_embed(x) + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + # Caution: We assume the subsampling factor is 4! + lengths = ((x_lens - 1) // 2 - 1) // 2 + assert x.size(0) == lengths.max().item() + mask = make_pad_mask(lengths) + + x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, N, C) + + if self.normalize_before: + x = self.after_norm(x) + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return x, lengths + + +class ConformerEncoderLayer(nn.Module): + """ + ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. + See: "Conformer: Convolution-augmented Transformer for Speech Recognition" + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. + normalize_before: whether to use layer_norm before the first block. + + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + normalize_before: bool = True, + ) -> None: + super(ConformerEncoderLayer, self).__init__() + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) + + self.feed_forward = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.feed_forward_macaron = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) + + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module + self.norm_ff = nn.LayerNorm(d_model) # for the FNN module + self.norm_mha = nn.LayerNorm(d_model) # for the MHA module + + self.ff_scale = 0.5 + + self.norm_conv = nn.LayerNorm(d_model) # for the CNN module + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block + + self.dropout = nn.Dropout(dropout) + + self.normalize_before = normalize_before + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, N is the batch size, E is the feature number + """ + + # macaron style feed forward module + residual = src + if self.normalize_before: + src = self.norm_ff_macaron(src) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) + if not self.normalize_before: + src = self.norm_ff_macaron(src) + + # multi-headed self-attention module + residual = src + if self.normalize_before: + src = self.norm_mha(src) + src_att = self.self_attn( + src, + src, + src, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = residual + self.dropout(src_att) + if not self.normalize_before: + src = self.norm_mha(src) + + # convolution module + residual = src + if self.normalize_before: + src = self.norm_conv(src) + src = residual + self.dropout(self.conv_module(src)) + if not self.normalize_before: + src = self.norm_conv(src) + + # feed forward module + residual = src + if self.normalize_before: + src = self.norm_ff(src) + src = residual + self.ff_scale * self.dropout(self.feed_forward(src)) + if not self.normalize_before: + src = self.norm_ff(src) + + if self.normalize_before: + src = self.norm_final(src) + + return src + + +class ConformerEncoder(nn.Module): + r"""ConformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the ConformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = conformer_encoder(src, pos_emb) + """ + + def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None: + super().__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + pos_emb: Positional embedding tensor (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + """ + output = src + + for mod in self.layers: + output = mod( + output, + pos_emb, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + ) + + return output + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module. + + See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. + + """ + + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: + """Construct an PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x: Tensor) -> None: + """Reset the positional encodings.""" + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + # Note: TorchScript doesn't implement operator== for torch.Device + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vecotr and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). + + """ + self.extend_pe(x) + x = x * self.xscale + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - x.size(1) + + 1 : self.pe.size(1) // 2 # noqa E203 + + x.size(1), + ] + return self.dropout(x), self.dropout(pos_emb) + + +class RelPositionMultiheadAttention(nn.Module): + r"""Multi-Head Attention layer with relative position encoding + + See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + + Examples:: + + >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + + # linear transformation for positional encoding. + self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + + self._reset_parameters() + + def _reset_parameters(self) -> None: + nn.init.xavier_uniform_(self.in_proj.weight) + nn.init.constant_(self.in_proj.bias, 0.0) + nn.init.constant_(self.out_proj.bias, 0.0) + + nn.init.xavier_uniform_(self.pos_bias_u) + nn.init.xavier_uniform_(self.pos_bias_v) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + return self.multi_head_attention_forward( + query, + key, + value, + pos_emb, + self.embed_dim, + self.num_heads, + self.in_proj.weight, + self.in_proj.bias, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + ) + + def rel_shift(self, x: Tensor) -> Tensor: + """Compute relative positional encoding. + + Args: + x: Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + Tensor: tensor of shape (batch, head, time1, time2) + (note: time2 has the same value as time1, but it is for + the key, while time1 is for the query). + """ + (batch_size, num_heads, time1, n) = x.shape + assert n == 2 * time1 - 1 + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time1), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) + + def multi_head_attention_forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence + length, N is the batch size, E is the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + assert ( + head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 + + if torch.equal(query, key) and torch.equal(key, value): + # self-attention + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) + + elif torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) + + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = nn.functional.linear(key, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = nn.functional.linear(value, _w, _b) + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + query.size(0), + key.size(0), + ]: + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim) + k = k.contiguous().view(-1, bsz, num_heads, head_dim) + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + src_len = k.size(0) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz, "{} == {}".format( + key_padding_mask.size(0), bsz + ) + assert key_padding_mask.size(1) == src_len, "{} == {}".format( + key_padding_mask.size(1), src_len + ) + + q = q.transpose(0, 1) # (batch, time1, head, d_k) + + pos_emb_bsz = pos_emb.size(0) + assert pos_emb_bsz in (1, bsz) # actually it is 1 + p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) + p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + q_with_bias_u = (q + self.pos_bias_u).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + q_with_bias_v = (q + self.pos_bias_v).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + # compute attention score + # first compute matrix a and matrix c + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) + + # compute matrix b and matrix d + matrix_bd = torch.matmul( + q_with_bias_v, p.transpose(-2, -1) + ) # (batch, head, time1, 2*time1-1) + matrix_bd = self.rel_shift(matrix_bd) + + attn_output_weights = ( + matrix_ac + matrix_bd + ) * scaling # (batch, head, time1, time2) + + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) + + assert list(attn_output_weights.size()) == [ + bsz * num_heads, + tgt_len, + src_len, + ] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float("-inf")) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len + ) + + attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias + ) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + self.norm = nn.LayerNorm(channels) + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = Swish() + + def forward(self, x: Tensor) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + # x is (batch, channels, time) + x = x.permute(0, 2, 1) + x = self.norm(x) + x = x.permute(0, 2, 1) + + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x: Tensor) -> Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x) + + +def identity(x): + return x diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/test_conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/test_conformer.py new file mode 100755 index 000000000..6b58c8bb6 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/test_conformer.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python3 ./pruned_transducer_stateless_multi_datasets/test_conformer.py +""" + +import torch +from conformer import Conformer + + +def test_conformer(): + d_model = 512 + conformer = Conformer( + num_features=80, + subsampling_factor=4, + d_model=d_model, + nhead=8, + dim_feedforward=2048, + num_encoder_layers=12, + ) + N = 3 + T = 100 + C = 80 + x = torch.randn(N, T, C) + x_lens = torch.tensor([50, 100, 80]) + logits, logit_lens = conformer(x, x_lens) + + expected_T = ((T - 1) // 2 - 1) // 2 + assert logits.shape == (N, expected_T, d_model) + assert logit_lens.max().item() == expected_T + print(logits.shape) + print(logit_lens) + + +def main(): + test_conformer() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/test_transformer.py b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/test_transformer.py new file mode 100755 index 000000000..4cdb11e21 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/test_transformer.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python3 ./pruned_transducer_stateless_multi_datasets/test_transformer.py +""" + +import torch +from transformer import Transformer + + +def test_transformer(): + d_model = 512 + transformer = Transformer( + num_features=80, + subsampling_factor=4, + d_model=d_model, + nhead=8, + dim_feedforward=2048, + num_encoder_layers=12, + ) + N = 3 + T = 100 + C = 80 + x = torch.randn(N, T, C) + x_lens = torch.tensor([50, 100, 80]) + logits, logit_lens = transformer(x, x_lens) + + expected_T = ((T - 1) // 2 - 1) // 2 + assert logits.shape == (N, expected_T, d_model) + assert logit_lens.max().item() == expected_T + print(logits.shape) + print(logit_lens) + + +def main(): + test_transformer() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/train.py index 8755ba8df..425879321 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/train.py @@ -288,7 +288,6 @@ def get_params() -> AttributeDict: "log_diagnostics": False, # parameters for conformer "feature_dim": 80, - "encoder_out_dim": 512, "subsampling_factor": 4, "attention_dim": 512, "nhead": 8, @@ -308,7 +307,6 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: # TODO: We can add an option to switch between Conformer and Transformer encoder = Conformer( num_features=params.feature_dim, - output_dim=params.encoder_out_dim, subsampling_factor=params.subsampling_factor, d_model=params.attention_dim, nhead=params.nhead, @@ -320,12 +318,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: def get_decoder_model(params: AttributeDict) -> nn.Module: - # Note: We set the embedding_dim of the decoder to - # vocab_size so that its output can be added with - # that of the encoder decoder = Decoder( vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, + embedding_dim=params.attention_dim, blank_id=params.blank_id, context_size=params.context_size, ) @@ -334,7 +329,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module: def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( - input_dim=params.encoder_out_dim, + input_dim=params.attention_dim, output_dim=params.vocab_size, ) return joiner diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/transformer.py b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/transformer.py deleted file mode 120000 index 2561f65c5..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/transformer.py +++ /dev/null @@ -1 +0,0 @@ -../transducer/transformer.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/transformer.py b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/transformer.py new file mode 100644 index 000000000..bce92f8f6 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/transformer.py @@ -0,0 +1,408 @@ +# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface +from subsampling import Conv2dSubsampling, VggSubsampling + +from icefall.utils import make_pad_mask + + +class Transformer(EncoderInterface): + def __init__( + self, + num_features: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + dropout: float = 0.1, + normalize_before: bool = True, + vgg_frontend: bool = False, + ) -> None: + """ + Args: + num_features: + The input dimension of the model. + subsampling_factor: + Number of output frames is num_in_frames // subsampling_factor. + Currently, subsampling_factor MUST be 4. + d_model: + Attention dimension. + nhead: + Number of heads in multi-head attention. + Must satisfy d_model // nhead == 0. + dim_feedforward: + The output dimension of the feedforward layers in encoder. + num_encoder_layers: + Number of encoder layers. + dropout: + Dropout in encoder. + normalize_before: + If True, use pre-layer norm; False to use post-layer norm. + vgg_frontend: + True to use vgg style frontend for subsampling. + """ + super().__init__() + + self.num_features = num_features + self.subsampling_factor = subsampling_factor + if subsampling_factor != 4: + raise NotImplementedError("Support only 'subsampling_factor=4'.") + + # self.encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, T//subsampling_factor, d_model). + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_features -> d_model + if vgg_frontend: + self.encoder_embed = VggSubsampling(num_features, d_model) + else: + self.encoder_embed = Conv2dSubsampling(num_features, d_model) + + self.encoder_pos = PositionalEncoding(d_model, dropout) + + encoder_layer = TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + normalize_before=normalize_before, + ) + + if normalize_before: + encoder_norm = nn.LayerNorm(d_model) + else: + encoder_norm = None + + self.encoder = nn.TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=num_encoder_layers, + norm=encoder_norm, + ) + + def forward( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + Returns: + Return a tuple containing 2 tensors: + - logits, its shape is (batch_size, output_seq_len, d_model) + - logit_lens, a tensor of shape (batch_size,) containing the number + of frames in `logits` before padding. + """ + x = self.encoder_embed(x) + x = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + # Caution: We assume the subsampling factor is 4! + lengths = ((x_lens - 1) // 2 - 1) // 2 + assert x.size(0) == lengths.max().item() + + mask = make_pad_mask(lengths) + x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C) + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return x, lengths + + +class TransformerEncoderLayer(nn.Module): + """ + Modified from torch.nn.TransformerEncoderLayer. + Add support of normalize_before, + i.e., use layer_norm before the first block. + + Args: + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). + normalize_before: + whether to use layer_norm before the first block. + + Examples:: + >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> out = encoder_layer(src) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "relu", + normalize_before: bool = True, + ) -> None: + super(TransformerEncoderLayer, self).__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + self.normalize_before = normalize_before + + def __setstate__(self, state): + if "activation" not in state: + state["activation"] = nn.functional.relu + super(TransformerEncoderLayer, self).__setstate__(state) + + def forward( + self, + src: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional) + + Shape: + src: (S, N, E). + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number + """ + residual = src + if self.normalize_before: + src = self.norm1(src) + src2 = self.self_attn( + src, + src, + src, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = residual + self.dropout1(src2) + if not self.normalize_before: + src = self.norm1(src) + + residual = src + if self.normalize_before: + src = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = residual + self.dropout2(src2) + if not self.normalize_before: + src = self.norm2(src) + return src + + +def _get_activation_fn(activation: str): + if activation == "relu": + return nn.functional.relu + elif activation == "gelu": + return nn.functional.gelu + + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) + + +class PositionalEncoding(nn.Module): + """This class implements the positional encoding + proposed in the following paper: + + - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf + + PE(pos, 2i) = sin(pos / (10000^(2i/d_modle)) + PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle)) + + Note:: + + 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model))) + = exp(-1* 2i / d_model * log(100000)) + = exp(2i * -(log(10000) / d_model)) + """ + + def __init__(self, d_model: int, dropout: float = 0.1) -> None: + """ + Args: + d_model: + Embedding dimension. + dropout: + Dropout probability to be applied to the output of this module. + """ + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = nn.Dropout(p=dropout) + # not doing: self.pe = None because of errors thrown by torchscript + self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32) + + def extend_pe(self, x: torch.Tensor) -> None: + """Extend the time t in the positional encoding if required. + + The shape of `self.pe` is (1, T1, d_model). The shape of the input x + is (N, T, d_model). If T > T1, then we change the shape of self.pe + to (N, T, d_model). Otherwise, nothing is done. + + Args: + x: + It is a tensor of shape (N, T, C). + Returns: + Return None. + """ + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + # Now pe is of shape (1, T, d_model), where T is x.size(1) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Add positional encoding. + + Args: + x: + Its shape is (N, T, C) + + Returns: + Return a tensor of shape (N, T, C) + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1), :] + return self.dropout(x) + + +class Noam(object): + """ + Implements Noam optimizer. + + Proposed in + "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf + + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa + + Args: + params: + iterable of parameters to optimize or dicts defining parameter groups + model_size: + attention dimension of the transformer model + factor: + learning rate factor + warm_step: + warmup steps + """ + + def __init__( + self, + params, + model_size: int = 256, + factor: float = 10.0, + warm_step: int = 25000, + weight_decay=0, + ) -> None: + """Construct an Noam object.""" + self.optimizer = torch.optim.Adam( + params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay + ) + self._step = 0 + self.warmup = warm_step + self.factor = factor + self.model_size = model_size + self._rate = 0 + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step=None): + """Implement `lrate` above.""" + if step is None: + step = self._step + return ( + self.factor + * self.model_size ** (-0.5) + * min(step ** (-0.5), step * self.warmup ** (-1.5)) + ) + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + "warmup": self.warmup, + "factor": self.factor, + "model_size": self.model_size, + "_rate": self._rate, + "optimizer": self.optimizer.state_dict(), + } + + def load_state_dict(self, state_dict): + """Load state_dict.""" + for key, value in state_dict.items(): + if key == "optimizer": + self.optimizer.load_state_dict(state_dict["optimizer"]) + else: + setattr(self, key, value)