From 4ccae509d3adf364e095573d6f944276c8868815 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 26 Jul 2021 20:06:58 +0800 Subject: [PATCH] WIP: Begin to add BPE decoding --- .gitignore | 2 + egs/librispeech/ASR/conformer_ctc/__init__.py | 0 .../ASR/conformer_ctc/conformer.py | 914 ++++++++++++++ egs/librispeech/ASR/conformer_ctc/decode.py | 129 ++ .../ASR/conformer_ctc/transformer.py | 1100 +++++++++++++++++ egs/librispeech/ASR/local/__init__.py | 0 egs/librispeech/ASR/local/compile_hlg.py | 61 +- egs/librispeech/ASR/local/prepare_lang.py | 9 +- egs/librispeech/ASR/local/prepare_lang_bpe.py | 196 +++ egs/librispeech/ASR/prepare.sh | 20 +- requirements.txt | 3 + 11 files changed, 2416 insertions(+), 18 deletions(-) create mode 100644 egs/librispeech/ASR/conformer_ctc/__init__.py create mode 100644 egs/librispeech/ASR/conformer_ctc/conformer.py create mode 100755 egs/librispeech/ASR/conformer_ctc/decode.py create mode 100644 egs/librispeech/ASR/conformer_ctc/transformer.py create mode 100644 egs/librispeech/ASR/local/__init__.py create mode 100755 egs/librispeech/ASR/local/prepare_lang_bpe.py create mode 100644 requirements.txt diff --git a/.gitignore b/.gitignore index b932c9080..6cb9f2299 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ data __pycache__ path.sh exp +exp*/ +*.pt diff --git a/egs/librispeech/ASR/conformer_ctc/__init__.py b/egs/librispeech/ASR/conformer_ctc/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py new file mode 100644 index 000000000..1e82eff2f --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/conformer.py @@ -0,0 +1,914 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# Apache 2.0 + +import math +import warnings +from typing import Optional, Tuple + +import torch +from torch import Tensor, nn +from transformer import Supervisions, Transformer, encoder_padding_mask + + +class Conformer(Transformer): + """ + Args: + num_features (int): Number of input features + num_classes (int): Number of output classes + subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) + d_model (int): attention dimension + nhead (int): number of head + dim_feedforward (int): feedforward dimention + num_encoder_layers (int): number of encoder layers + num_decoder_layers (int): number of decoder layers + dropout (float): dropout rate + cnn_module_kernel (int): Kernel size of convolution module + normalize_before (bool): whether to use layer_norm before the first block. + vgg_frontend (bool): whether to use vgg frontend. + """ + + def __init__( + self, + num_features: int, + num_classes: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + num_decoder_layers: int = 6, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + normalize_before: bool = True, + vgg_frontend: bool = False, + is_espnet_structure: bool = False, + mmi_loss: bool = True, + use_feat_batchnorm: bool = False, + ) -> None: + super(Conformer, self).__init__( + num_features=num_features, + num_classes=num_classes, + subsampling_factor=subsampling_factor, + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + dropout=dropout, + normalize_before=normalize_before, + vgg_frontend=vgg_frontend, + mmi_loss=mmi_loss, + use_feat_batchnorm=use_feat_batchnorm, + ) + + self.encoder_pos = RelPositionalEncoding(d_model, dropout) + + encoder_layer = ConformerEncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + cnn_module_kernel, + normalize_before, + is_espnet_structure, + ) + self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) + self.normalize_before = normalize_before + self.is_espnet_structure = is_espnet_structure + if self.normalize_before and self.is_espnet_structure: + self.after_norm = nn.LayerNorm(d_model) + else: + # Note: TorchScript detects that self.after_norm could be used inside forward() + # and throws an error without this change. + self.after_norm = identity + + def encode( + self, x: Tensor, supervisions: Optional[Supervisions] = None + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x: Tensor of dimension (batch_size, num_features, input_length). + supervisions : Supervison in lhotse format, i.e., batch['supervisions'] + + Returns: + Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). + Tensor: Mask tensor of dimension (batch_size, input_length) + """ + x = x.permute(0, 2, 1) # (B, F, T) -> (B, T, F) + + x = self.encoder_embed(x) + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + mask = encoder_padding_mask(x.size(0), supervisions) + if mask is not None: + mask = mask.to(x.device) + x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F) + + if self.normalize_before and self.is_espnet_structure: + x = self.after_norm(x) + + return x, mask + + +class ConformerEncoderLayer(nn.Module): + """ + ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. + See: "Conformer: Convolution-augmented Transformer for Speech Recognition" + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. + normalize_before: whether to use layer_norm before the first block. + + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + normalize_before: bool = True, + is_espnet_structure: bool = False, + ) -> None: + super(ConformerEncoderLayer, self).__init__() + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0, is_espnet_structure=is_espnet_structure + ) + + self.feed_forward = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.feed_forward_macaron = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) + + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module + self.norm_ff = nn.LayerNorm(d_model) # for the FNN module + self.norm_mha = nn.LayerNorm(d_model) # for the MHA module + + self.ff_scale = 0.5 + + self.norm_conv = nn.LayerNorm(d_model) # for the CNN module + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block + + self.dropout = nn.Dropout(dropout) + + self.normalize_before = normalize_before + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, N is the batch size, E is the feature number + """ + + # macaron style feed forward module + residual = src + if self.normalize_before: + src = self.norm_ff_macaron(src) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) + if not self.normalize_before: + src = self.norm_ff_macaron(src) + + # multi-headed self-attention module + residual = src + if self.normalize_before: + src = self.norm_mha(src) + src_att = self.self_attn( + src, + src, + src, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = residual + self.dropout(src_att) + if not self.normalize_before: + src = self.norm_mha(src) + + # convolution module + residual = src + if self.normalize_before: + src = self.norm_conv(src) + src = residual + self.dropout(self.conv_module(src)) + if not self.normalize_before: + src = self.norm_conv(src) + + # feed forward module + residual = src + if self.normalize_before: + src = self.norm_ff(src) + src = residual + self.ff_scale * self.dropout(self.feed_forward(src)) + if not self.normalize_before: + src = self.norm_ff(src) + + if self.normalize_before: + src = self.norm_final(src) + + return src + + +class ConformerEncoder(nn.TransformerEncoder): + r"""ConformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the ConformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = conformer_encoder(src, pos_emb) + """ + + def __init__( + self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None + ) -> None: + super(ConformerEncoder, self).__init__( + encoder_layer=encoder_layer, num_layers=num_layers, norm=norm + ) + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + pos_emb: Positional embedding tensor (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + """ + output = src + + for mod in self.layers: + output = mod( + output, + pos_emb, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + ) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module. + + See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. + + """ + + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: + """Construct an PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x: Tensor) -> None: + """Reset the positional encodings.""" + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + # Note: TorchScript doesn't implement operator== for torch.Device + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vecotr and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). + + """ + self.extend_pe(x) + x = x * self.xscale + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - x.size(1) + + 1 : self.pe.size(1) // 2 + + x.size(1), + ] + return self.dropout(x), self.dropout(pos_emb) + + +class RelPositionMultiheadAttention(nn.Module): + r"""Multi-Head Attention layer with relative position encoding + + See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + + Examples:: + + >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_espnet_structure: bool = False, + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + + # linear transformation for positional encoding. + self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + + self._reset_parameters() + + self.is_espnet_structure = is_espnet_structure + + def _reset_parameters(self) -> None: + nn.init.xavier_uniform_(self.in_proj.weight) + nn.init.constant_(self.in_proj.bias, 0.0) + nn.init.constant_(self.out_proj.bias, 0.0) + + nn.init.xavier_uniform_(self.pos_bias_u) + nn.init.xavier_uniform_(self.pos_bias_v) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + return self.multi_head_attention_forward( + query, + key, + value, + pos_emb, + self.embed_dim, + self.num_heads, + self.in_proj.weight, + self.in_proj.bias, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + ) + + def rel_shift(self, x: Tensor) -> Tensor: + """Compute relative positional encoding. + + Args: + x: Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + Tensor: tensor of shape (batch, head, time1, time2) + (note: time2 has the same value as time1, but it is for + the key, while time1 is for the query). + """ + (batch_size, num_heads, time1, n) = x.shape + assert n == 2 * time1 - 1 + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time1), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) + + def multi_head_attention_forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence + length, N is the batch size, E is the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + assert ( + head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 + + if torch.equal(query, key) and torch.equal(key, value): + # self-attention + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) + + elif torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) + + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = nn.functional.linear(key, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = nn.functional.linear(value, _w, _b) + + if not self.is_espnet_structure: + q = q * scaling + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + query.size(0), + key.size(0), + ]: + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim) + k = k.contiguous().view(-1, bsz, num_heads, head_dim) + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + src_len = k.size(0) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz, "{} == {}".format( + key_padding_mask.size(0), bsz + ) + assert key_padding_mask.size(1) == src_len, "{} == {}".format( + key_padding_mask.size(1), src_len + ) + + q = q.transpose(0, 1) # (batch, time1, head, d_k) + + pos_emb_bsz = pos_emb.size(0) + assert pos_emb_bsz in (1, bsz) # actually it is 1 + p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) + p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + q_with_bias_u = (q + self.pos_bias_u).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + q_with_bias_v = (q + self.pos_bias_v).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + # compute attention score + # first compute matrix a and matrix c + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) + + # compute matrix b and matrix d + matrix_bd = torch.matmul( + q_with_bias_v, p.transpose(-2, -1) + ) # (batch, head, time1, 2*time1-1) + matrix_bd = self.rel_shift(matrix_bd) + + if not self.is_espnet_structure: + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) + else: + attn_output_weights = ( + matrix_ac + matrix_bd + ) * scaling # (batch, head, time1, time2) + + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) + + assert list(attn_output_weights.size()) == [ + bsz * num_heads, + tgt_len, + src_len, + ] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float("-inf")) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len + ) + + attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias + ) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + self.norm = nn.BatchNorm1d(channels) + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = Swish() + + def forward(self, x: Tensor) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + x = self.activation(self.norm(x)) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x: Tensor) -> Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x) + + +def identity(x): + return x diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py new file mode 100755 index 000000000..74fd3060b --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 + +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang) + +# (still working in progress) + +import argparse +import logging +from pathlib import Path + +import torch +from conformer import Conformer + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.dataset.librispeech import LibriSpeechAsrDataModule +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + write_error_stats, +) + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "exp_dir": Path("conformer_ctc/exp"), + "lang_dir": Path("data/lang/bpe"), + "lm_dir": Path("data/lm"), + "feature_dim": 80, + "nhead": 8, + "attention_dim": 512, + "num_classes": 5000, + "subsampling_factor": 4, + "num_decoder_layers": 6, + "vgg_frontend": False, + "is_espnet_structure": True, + "mmi_loss": False, + "use_feat_batchnorm": True, + "search_beam": 20, + "output_beam": 5, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + # Possible values for method: + # - 1best + # - nbest + # - nbest-rescoring + # - whole-lattice-rescoring + "method": "whole-lattice-rescoring", + # num_paths is used when method is "nbest" and "nbest-rescoring" + "num_paths": 30, + } + ) + return params + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=9, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + return parser + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log/log-decode") + logging.info("Decoding started") + logging.info(params) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=params.num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=params.vgg_frontend, + is_espnet_structure=params.is_espnet_structure, + mmi_loss=params.mmi_loss, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames)) + + model.to(device) + model.eval() + token_ids_with_blank = list(range(params.num_classes)) + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py new file mode 100644 index 000000000..e302cfeaf --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -0,0 +1,1100 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# Apache 2.0 + +import math +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +from torch import Tensor, nn + +from icefall.utils import get_texts + +# Note: TorchScript requires Dict/List/etc. to be fully typed. +Supervisions = Dict[str, Tensor] + + +class Transformer(nn.Module): + """ + Args: + num_features (int): Number of input features + num_classes (int): Number of output classes + subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) + d_model (int): attention dimension + nhead (int): number of head + dim_feedforward (int): feedforward dimention + num_encoder_layers (int): number of encoder layers + num_decoder_layers (int): number of decoder layers + dropout (float): dropout rate + normalize_before (bool): whether to use layer_norm before the first block. + vgg_frontend (bool): whether to use vgg frontend. + """ + + def __init__( + self, + num_features: int, + num_classes: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + num_decoder_layers: int = 6, + dropout: float = 0.1, + normalize_before: bool = True, + vgg_frontend: bool = False, + mmi_loss: bool = True, + use_feat_batchnorm: bool = False, + ) -> None: + super().__init__() + self.use_feat_batchnorm = use_feat_batchnorm + if use_feat_batchnorm: + self.feat_batchnorm = nn.BatchNorm1d(num_features) + + self.num_features = num_features + self.num_classes = num_classes + self.subsampling_factor = subsampling_factor + if subsampling_factor != 4: + raise NotImplementedError("Support only 'subsampling_factor=4'.") + + self.encoder_embed = ( + VggSubsampling(num_features, d_model) + if vgg_frontend + else Conv2dSubsampling(num_features, d_model) + ) + self.encoder_pos = PositionalEncoding(d_model, dropout) + + encoder_layer = TransformerEncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + normalize_before=normalize_before, + ) + + if normalize_before: + encoder_norm = nn.LayerNorm(d_model) + else: + encoder_norm = None + + self.encoder = nn.TransformerEncoder( + encoder_layer, num_encoder_layers, encoder_norm + ) + + self.encoder_output_layer = nn.Sequential( + nn.Dropout(p=dropout), nn.Linear(d_model, num_classes) + ) + + if num_decoder_layers > 0: + if mmi_loss: + self.decoder_num_class = ( + self.num_classes + 1 + ) # +1 for the sos/eos symbol + else: + self.decoder_num_class = ( + self.num_classes + ) # bpe model already has sos/eos symbol + + self.decoder_embed = nn.Embedding(self.decoder_num_class, d_model) + self.decoder_pos = PositionalEncoding(d_model, dropout) + + decoder_layer = TransformerDecoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + normalize_before=normalize_before, + ) + + if normalize_before: + decoder_norm = nn.LayerNorm(d_model) + else: + decoder_norm = None + + self.decoder = nn.TransformerDecoder( + decoder_layer, num_decoder_layers, decoder_norm + ) + + self.decoder_output_layer = torch.nn.Linear( + d_model, self.decoder_num_class + ) + + self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class) + else: + self.decoder_criterion = None + + def forward( + self, x: Tensor, supervision: Optional[Supervisions] = None + ) -> Tuple[Tensor, Tensor, Optional[Tensor]]: + """ + Args: + x: Tensor of dimension (batch_size, num_features, input_length). + supervision: Supervison in lhotse format, get from batch['supervisions'] + + Returns: + Tensor: After log-softmax tensor of dimension (batch_size, number_of_classes, input_length). + Tensor: Before linear layer tensor of dimension (input_length, batch_size, d_model). + Optional[Tensor]: Mask tensor of dimension (batch_size, input_length) or None. + + """ + if self.use_feat_batchnorm: + x = self.feat_batchnorm(x) + encoder_memory, memory_mask = self.encode(x, supervision) + x = self.encoder_output(encoder_memory) + return x, encoder_memory, memory_mask + + def encode( + self, x: Tensor, supervisions: Optional[Supervisions] = None + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x: Tensor of dimension (batch_size, num_features, input_length). + supervisions : Supervison in lhotse format, i.e., batch['supervisions'] + + Returns: + Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). + Optional[Tensor]: Mask tensor of dimension (batch_size, input_length) or None. + """ + x = x.permute(0, 2, 1) # (B, F, T) -> (B, T, F) + + x = self.encoder_embed(x) + x = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + mask = encoder_padding_mask(x.size(0), supervisions) + mask = mask.to(x.device) if mask != None else None + x = self.encoder(x, src_key_padding_mask=mask) # (T, B, F) + + return x, mask + + def encoder_output(self, x: Tensor) -> Tensor: + """ + Args: + x: Tensor of dimension (input_length, batch_size, d_model). + + Returns: + Tensor: After log-softmax tensor of dimension (batch_size, number_of_classes, input_length). + """ + x = self.encoder_output_layer(x).permute( + 1, 2, 0 + ) # (T, B, F) ->(B, F, T) + x = nn.functional.log_softmax(x, dim=1) # (B, F, T) + return x + + def decoder_forward( + self, + x: Tensor, + encoder_mask: Tensor, + supervision: Supervisions = None, + graph_compiler: object = None, + token_ids: List[int] = None, + ) -> Tensor: + """ + Args: + x: Tensor of dimension (input_length, batch_size, d_model). + encoder_mask: Mask tensor of dimension (batch_size, input_length) + supervision: Supervison in lhotse format, get from batch['supervisions'] + graph_compiler: use graph_compiler.L_inv (Its labels are words, while its aux_labels are phones) + , graph_compiler.words and graph_compiler.oov + + Returns: + Tensor: Decoder loss. + """ + if supervision is not None and graph_compiler is not None: + batch_text = get_normal_transcripts( + supervision, graph_compiler.lexicon.words, graph_compiler.oov + ) + ys_in_pad, ys_out_pad = add_sos_eos( + batch_text, + graph_compiler.L_inv, + self.decoder_num_class - 1, + self.decoder_num_class - 1, + ) + elif token_ids is not None: + # speical token ids: + # 0 + # 1 + # self.decoder_num_class - 1 + sos_id = self.decoder_num_class - 1 + eos_id = self.decoder_num_class - 1 + _sos = torch.tensor([sos_id]) + _eos = torch.tensor([eos_id]) + ys_in = [ + torch.cat([_sos, torch.tensor(y)], dim=0) for y in token_ids + ] + ys_out = [ + torch.cat([torch.tensor(y), _eos], dim=0) for y in token_ids + ] + ys_in_pad = pad_list(ys_in, eos_id) + ys_out_pad = pad_list(ys_out, -1) + + else: + raise ValueError("Invalid input for decoder self attetion") + + ys_in_pad = ys_in_pad.to(x.device) + ys_out_pad = ys_out_pad.to(x.device) + + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + x.device + ) + + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad) + + tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) + tgt = self.decoder_pos(tgt) + tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + pred_pad = self.decoder( + tgt=tgt, + memory=x, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=encoder_mask, + ) # (T, B, F) + pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) + pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) + + decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad) + + return decoder_loss + + def decoder_nll( + self, x: Tensor, encoder_mask: Tensor, token_ids: List[int] = None + ) -> Tensor: + """ + Args: + x: encoder-output, Tensor of dimension (input_length, batch_size, d_model). + encoder_mask: Mask tensor of dimension (batch_size, input_length) + token_ids: n-best list extracted from lattice before rescore + + Returns: + Tensor: negative log-likelihood. + """ + # The common part between this fuction and decoder_forward could be + # extracted as a seperated function. + if token_ids is not None: + # speical token ids: + # 0 + # 1 + # self.decoder_num_class - 1 + sos_id = self.decoder_num_class - 1 + eos_id = self.decoder_num_class - 1 + _sos = torch.tensor([sos_id]) + _eos = torch.tensor([eos_id]) + ys_in = [ + torch.cat([_sos, torch.tensor(y)], dim=0) for y in token_ids + ] + ys_out = [ + torch.cat([torch.tensor(y), _eos], dim=0) for y in token_ids + ] + ys_in_pad = pad_list(ys_in, eos_id) + ys_out_pad = pad_list(ys_out, -1) + else: + raise ValueError("Invalid input for decoder self attetion") + + ys_in_pad = ys_in_pad.to(x.device, dtype=torch.int64) + ys_out_pad = ys_out_pad.to(x.device, dtype=torch.int64) + + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + x.device + ) + + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad) + + tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) + tgt = self.decoder_pos(tgt) + tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + pred_pad = self.decoder( + tgt=tgt, + memory=x, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=encoder_mask, + ) # (T, B, F) + pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) + pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) + # nll: negative log-likelihood + nll = torch.nn.functional.cross_entropy( + pred_pad.view(-1, self.decoder_num_class), + ys_out_pad.view(-1), + ignore_index=-1, + reduction="none", + ) + + nll = nll.view(pred_pad.shape[0], -1) + + return nll + + +class TransformerEncoderLayer(nn.Module): + """ + Modified from torch.nn.TransformerEncoderLayer. Add support of normalize_before, + i.e., use layer_norm before the first block. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of intermediate layer, relu or gelu (default=relu). + normalize_before: whether to use layer_norm before the first block. + + Examples:: + >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> out = encoder_layer(src) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "relu", + normalize_before: bool = True, + ) -> None: + super(TransformerEncoderLayer, self).__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + self.normalize_before = normalize_before + + def __setstate__(self, state): + if "activation" not in state: + state["activation"] = nn.functional.relu + super(TransformerEncoderLayer, self).__setstate__(state) + + def forward( + self, + src: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + """ + residual = src + if self.normalize_before: + src = self.norm1(src) + src2 = self.self_attn( + src, + src, + src, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = residual + self.dropout1(src2) + if not self.normalize_before: + src = self.norm1(src) + + residual = src + if self.normalize_before: + src = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = residual + self.dropout2(src2) + if not self.normalize_before: + src = self.norm2(src) + return src + + +class TransformerDecoderLayer(nn.Module): + """ + Modified from torch.nn.TransformerDecoderLayer. Add support of normalize_before, + i.e., use layer_norm before the first block. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of intermediate layer, relu or gelu (default=relu). + + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = decoder_layer(tgt, memory) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "relu", + normalize_before: bool = True, + ) -> None: + super(TransformerDecoderLayer, self).__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + self.src_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + self.normalize_before = normalize_before + + def __setstate__(self, state): + if "activation" not in state: + state["activation"] = nn.functional.relu + super(TransformerDecoderLayer, self).__setstate__(state) + + def forward( + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: the sequence to the decoder layer (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + + Shape: + tgt: (T, N, E). + memory: (S, N, E). + tgt_mask: (T, T). + memory_mask: (T, S). + tgt_key_padding_mask: (N, T). + memory_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + """ + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + tgt2 = self.self_attn( + tgt, + tgt, + tgt, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask, + )[0] + tgt = residual + self.dropout1(tgt2) + if not self.normalize_before: + tgt = self.norm1(tgt) + + residual = tgt + if self.normalize_before: + tgt = self.norm2(tgt) + tgt2 = self.src_attn( + tgt, + memory, + memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = residual + self.dropout2(tgt2) + if not self.normalize_before: + tgt = self.norm2(tgt) + + residual = tgt + if self.normalize_before: + tgt = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = residual + self.dropout3(tgt2) + if not self.normalize_before: + tgt = self.norm3(tgt) + return tgt + + +def _get_activation_fn(activation: str): + if activation == "relu": + return nn.functional.relu + elif activation == "gelu": + return nn.functional.gelu + + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py + + Args: + idim: Input dimension. + odim: Output dimension. + + """ + + def __init__(self, idim: int, odim: int) -> None: + """Construct a Conv2dSubsampling object.""" + super(Conv2dSubsampling, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=1, out_channels=odim, kernel_size=3, stride=2 + ), + nn.ReLU(), + nn.Conv2d( + in_channels=odim, out_channels=odim, kernel_size=3, stride=2 + ), + nn.ReLU(), + ) + self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) + + def forward(self, x: Tensor) -> Tensor: + """Subsample x. + + Args: + x: Input tensor of dimension (batch_size, input_length, num_features). (#batch, time, idim). + + Returns: + torch.Tensor: Subsampled tensor of dimension (batch_size, input_length, d_model). + where time' = time // 4. + + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + return x + + +class VggSubsampling(nn.Module): + """Trying to follow the setup described here https://arxiv.org/pdf/1910.09799.pdf + This paper is not 100% explicit so I am guessing to some extent, + and trying to compare with other VGG implementations. + + Args: + idim: Input dimension. + odim: Output dimension. + + """ + + def __init__(self, idim: int, odim: int) -> None: + """Construct a VggSubsampling object. This uses 2 VGG blocks with 2 + Conv2d layers each, subsampling its input by a factor of 4 in the + time dimensions. + + Args: + idim: Number of features at input, e.g. 40 or 80 for MFCC + (will be treated as the image height). + odim: Output dimension (number of features), e.g. 256 + """ + super(VggSubsampling, self).__init__() + + cur_channels = 1 + layers = [] + block_dims = [32, 64] + + # The decision to use padding=1 for the 1st convolution, then padding=0 + # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by + # a back-compatibility concern so that the number of frames at the + # output would be equal to: + # (((T-1)//2)-1)//2. + # We can consider changing this by using padding=1 on the 2nd convolution, + # so the num-frames at the output would be T//4. + for block_dim in block_dims: + layers.append( + torch.nn.Conv2d( + in_channels=cur_channels, + out_channels=block_dim, + kernel_size=3, + padding=1, + stride=1, + ) + ) + layers.append(torch.nn.ReLU()) + layers.append( + torch.nn.Conv2d( + in_channels=block_dim, + out_channels=block_dim, + kernel_size=3, + padding=0, + stride=1, + ) + ) + layers.append( + torch.nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True + ) + ) + cur_channels = block_dim + + self.layers = nn.Sequential(*layers) + + self.out = nn.Linear( + block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim + ) + + def forward(self, x: Tensor) -> Tensor: + """Subsample x. + + Args: + x: Input tensor of dimension (batch_size, input_length, num_features). (#batch, time, idim). + + Returns: + torch.Tensor: Subsampled tensor of dimension (batch_size, input_length', d_model). + where input_length' == (((input_length - 1) // 2) - 1) // 2 + + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.layers(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + return x + + +class PositionalEncoding(nn.Module): + """ + Positional encoding. + + Args: + d_model: Embedding dimension. + dropout: Dropout rate. + max_len: Maximum input length. + + """ + + def __init__( + self, d_model: int, dropout: float = 0.1, max_len: int = 5000 + ) -> None: + """Construct an PositionalEncoding object.""" + super(PositionalEncoding, self).__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = nn.Dropout(p=dropout) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x: Tensor) -> None: + """Reset the positional encodings.""" + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: Tensor) -> Tensor: + """ + Add positional encoding. + + Args: + x: Input tensor of dimention (batch_size, input_length, d_model). + + Returns: + torch.Tensor: Encoded tensor of dimention (batch_size, input_length, d_model). + + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1)] + return self.dropout(x) + + +class Noam(object): + """ + Implements Noam optimizer. Proposed in "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py + + Args: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups + model_size: attention dimension of the transformer model + factor: learning rate factor + warm_step: warmup steps + """ + + def __init__( + self, + params, + model_size: int = 256, + factor: float = 10.0, + warm_step: int = 25000, + weight_decay=0, + ) -> None: + """Construct an Noam object.""" + self.optimizer = torch.optim.Adam( + params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay + ) + self._step = 0 + self.warmup = warm_step + self.factor = factor + self.model_size = model_size + self._rate = 0 + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step=None): + """Implement `lrate` above.""" + if step is None: + step = self._step + return ( + self.factor + * self.model_size ** (-0.5) + * min(step ** (-0.5), step * self.warmup ** (-1.5)) + ) + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + "warmup": self.warmup, + "factor": self.factor, + "model_size": self.model_size, + "_rate": self._rate, + "optimizer": self.optimizer.state_dict(), + } + + def load_state_dict(self, state_dict): + """Load state_dict.""" + for key, value in state_dict.items(): + if key == "optimizer": + self.optimizer.load_state_dict(state_dict["optimizer"]) + else: + setattr(self, key, value) + + +class LabelSmoothingLoss(nn.Module): + """ + Label-smoothing loss. KL-divergence between q_{smoothed ground truth prob.}(w) + and p_{prob. computed by model}(w) is minimized. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py + + Args: + size: the number of class + padding_idx: padding_idx: ignored class id + smoothing: smoothing rate (0.0 means the conventional CE) + normalize_length: normalize loss by sequence length if True + criterion: loss function to be smoothed + """ + + def __init__( + self, + size: int, + padding_idx: int = -1, + smoothing: float = 0.1, + normalize_length: bool = False, + criterion: nn.Module = nn.KLDivLoss(reduction="none"), + ) -> None: + """Construct an LabelSmoothingLoss object.""" + super(LabelSmoothingLoss, self).__init__() + self.criterion = criterion + self.padding_idx = padding_idx + assert 0.0 < smoothing <= 1.0 + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + self.size = size + self.true_dist = None + self.normalize_length = normalize_length + + def forward(self, x: Tensor, target: Tensor) -> Tensor: + """ + Compute loss between x and target. + + Args: + x: prediction of dimention (batch_size, input_length, number_of_classes). + target: target masked with self.padding_id of dimention (batch_size, input_length). + + Returns: + torch.Tensor: scalar float value + """ + assert x.size(2) == self.size + batch_size = x.size(0) + x = x.view(-1, self.size) + target = target.view(-1) + with torch.no_grad(): + true_dist = x.clone() + true_dist.fill_(self.smoothing / (self.size - 1)) + ignore = target == self.padding_idx # (B,) + total = len(target) - ignore.sum().item() + target = target.masked_fill(ignore, 0) # avoid -1 index + true_dist.scatter_(1, target.unsqueeze(1), self.confidence) + kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) + denom = total if self.normalize_length else batch_size + return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom + + +def encoder_padding_mask( + max_len: int, supervisions: Optional[Supervisions] = None +) -> Optional[Tensor]: + """Make mask tensor containing indices of padded part. + + Args: + max_len: maximum length of input features + supervisions : Supervison in lhotse format, i.e., batch['supervisions'] + + Returns: + Tensor: Mask tensor of dimension (batch_size, input_length), True denote the masked indices. + """ + if supervisions is None: + return None + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"], + supervisions["num_frames"], + ), + 1, + ).to(torch.int32) + + lengths = [ + 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) + ] + for idx in range(supervision_segments.size(0)): + # Note: TorchScript doesn't allow to unpack tensors as tuples + sequence_idx = supervision_segments[idx, 0].item() + start_frame = supervision_segments[idx, 1].item() + num_frames = supervision_segments[idx, 2].item() + lengths[sequence_idx] = start_frame + num_frames + + lengths = [((i - 1) // 2 - 1) // 2 for i in lengths] + bs = int(len(lengths)) + seq_range = torch.arange(0, max_len, dtype=torch.int64) + seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len) + # Note: TorchScript doesn't implement Tensor.new() + seq_length_expand = torch.tensor( + lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype + ).unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + + return mask + + +def decoder_padding_mask(ys_pad: Tensor, ignore_id: int = -1) -> Tensor: + """Generate a length mask for input. The masked position are filled with bool(True), + Unmasked positions are filled with bool(False). + + Args: + ys_pad: padded tensor of dimension (batch_size, input_length). + ignore_id: the ignored number (the padding number) in ys_pad + + Returns: + Tensor: a mask tensor of dimension (batch_size, input_length). + """ + ys_mask = ys_pad == ignore_id + return ys_mask + + +def get_normal_transcripts( + supervision: Supervisions, words: k2.SymbolTable, oov: str = "" +) -> List[List[int]]: + """Get normal transcripts (1 input recording has 1 transcript) from lhotse cut format. + Achieved by concatenate the transcripts corresponding to the same recording. + + Args: + supervision : Supervison in lhotse format, i.e., batch['supervisions'] + words: The word symbol table. + oov: Out of vocabulary word. + + Returns: + List[List[int]]: List of concatenated transcripts, length is batch_size + """ + + texts = [ + [token if token in words else oov for token in text.split(" ")] + for text in supervision["text"] + ] + texts_ids = [[words[token] for token in text] for text in texts] + + batch_text = [ + [] for _ in range(int(supervision["sequence_idx"].max().item()) + 1) + ] + for sequence_idx, text in zip(supervision["sequence_idx"], texts_ids): + batch_text[sequence_idx] = batch_text[sequence_idx] + text + return batch_text + + +def generate_square_subsequent_mask(sz: int) -> Tensor: + """Generate a square mask for the sequence. The masked positions are filled with float('-inf'). + Unmasked positions are filled with float(0.0). + + Args: + sz: mask size + + Returns: + Tensor: a square mask of dimension (sz, sz) + """ + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = ( + mask.float() + .masked_fill(mask == 0, float("-inf")) + .masked_fill(mask == 1, float(0.0)) + ) + return mask + + +def add_sos_eos( + ys: List[List[int]], + lexicon: k2.Fsa, + sos: int, + eos: int, + ignore_id: int = -1, +) -> Tuple[Tensor, Tensor]: + """Add and labels. + + Args: + ys: batch of unpadded target sequences + lexicon: Its labels are words, while its aux_labels are phones. + sos: index of + eos: index of + ignore_id: index of padding + + Returns: + Tensor: Input of transformer decoder. Padded tensor of dimention (batch_size, max_length). + Tensor: Output of transformer decoder. padded tensor of dimention (batch_size, max_length). + """ + + _sos = torch.tensor([sos]) + _eos = torch.tensor([eos]) + ys = get_hierarchical_targets(ys, lexicon) + ys_in = [torch.cat([_sos, y], dim=0) for y in ys] + ys_out = [torch.cat([y, _eos], dim=0) for y in ys] + return pad_list(ys_in, eos), pad_list(ys_out, ignore_id) + + +def pad_list(ys: List[Tensor], pad_value: float) -> Tensor: + """Perform padding for the list of tensors. + + Args: + ys: List of tensors. len(ys) = batch_size. + pad_value: Value for padding. + + Returns: + Tensor: Padded tensor (batch_size, max_length, `*`). + + Examples: + >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] + >>> x + [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] + >>> pad_list(x, 0) + tensor([[1., 1., 1., 1.], + [1., 1., 0., 0.], + [1., 0., 0., 0.]]) + + """ + n_batch = len(ys) + max_len = max(x.size(0) for x in ys) + pad = ys[0].new_full((n_batch, max_len, *ys[0].size()[1:]), pad_value) + + for i in range(n_batch): + pad[i, : ys[i].size(0)] = ys[i] + + return pad + + +def get_hierarchical_targets( + ys: List[List[int]], lexicon: k2.Fsa +) -> List[Tensor]: + """Get hierarchical transcripts (i.e., phone level transcripts) from transcripts (i.e., word level transcripts). + + Args: + ys: Word level transcripts. + lexicon: Its labels are words, while its aux_labels are phones. + + Returns: + List[Tensor]: Phone level transcripts. + + """ + + if lexicon is None: + return ys + else: + L_inv = lexicon + + n_batch = len(ys) + device = L_inv.device + + transcripts = k2.create_fsa_vec( + [k2.linear_fsa(x, device=device) for x in ys] + ) + transcripts_with_self_loops = k2.add_epsilon_self_loops(transcripts) + + transcripts_lexicon = k2.intersect( + L_inv, transcripts_with_self_loops, treat_epsilons_specially=False + ) + # Don't call invert_() above because we want to return phone IDs, + # which is the `aux_labels` of transcripts_lexicon + transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon) + transcripts_lexicon = k2.top_sort(transcripts_lexicon) + + transcripts_lexicon = k2.shortest_path( + transcripts_lexicon, use_double_scores=True + ) + + ys = get_texts(transcripts_lexicon) + ys = [torch.tensor(y) for y in ys] + + return ys + + +def test_transformer(): + t = Transformer(40, 1281) + T = 200 + f = torch.rand(31, 40, T) + g, _, _ = t(f) + assert g.shape == (31, 1281, (((T - 1) // 2) - 1) // 2) + + +def main(): + test_transformer() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/local/__init__.py b/egs/librispeech/ASR/local/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index a4ec93728..c30bf9fba 100644 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -3,28 +3,48 @@ """ This script compiles HLG from - - H, the ctc topology, built from phones contained in data/lang/lexicon.txt - - L, the lexicon, built from data/lang/L_disambig.pt + - H, the ctc topology, built from phones contained in lexicon.txt + - L, the lexicon, built from L_disambig.pt Caution: We use a lexicon that contains disambiguation symbols - G, the LM, built from data/lm/G_3_gram.fst.txt -The generated HLG is saved in data/lm/HLG.pt +The generated HLG is saved in data/lm/HLG.pt (phone based) +or data/lm/HLG_bpe.pt (BPE based) """ +from pathlib import Path + import k2 import torch from icefall.lexicon import Lexicon -def main(): - lexicon = Lexicon("data/lang") +def compile_HLG(lang_dir: str) -> k2.Fsa: + """ + Args: + lang_dir: + The language directory, e.g., data/lang or data/lang/bpe. + + Return: + An FSA representing HLG. + """ + lexicon = Lexicon(lang_dir) max_token_id = max(lexicon.tokens) + print(f"building ctc_top. max_token_id: {max_token_id}") H = k2.ctc_topo(max_token_id) - L = k2.Fsa.from_dict(torch.load("data/lang/L_disambig.pt")) - with open("data/lm/G_3_gram.fst.txt") as f: - G = k2.Fsa.from_openfst(f.read(), acceptor=False) + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + + if Path("data/lm/G_3_gram.pt").is_file(): + print("Loading pre-compiled G_3_gram") + d = torch.load("data/lm/G_3_gram.pt") + G = k2.Fsa.from_dict(d).to(device) + else: + print("Loading G_3_gram.fst.txt") + with open("data/lm/G_3_gram.fst.txt") as f: + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + torch.save(G.as_dict(), "G_3_gram.pt") first_token_disambig_id = lexicon.phones["#0"] first_word_disambig_id = lexicon.words["#0"] @@ -74,9 +94,34 @@ def main(): print("Arc sorting LG") HLG = k2.arc_sort(HLG) + return HLG + + +def phone_based_HLG(): + if Path("data/lm/HLG.pt").is_file(): + return + + print("Compiling phone based HLG") + HLG = compile_HLG("data/lang") + print("Saving HLG.pt to data/lm") torch.save(HLG.as_dict(), "data/lm/HLG.pt") +def bpe_based_HLG(): + if Path("data/lm/HLG_bpe.pt").is_file(): + return + + print("Compiling BPE based HLG") + HLG = compile_HLG("data/lang/bpe") + print("Saving HLG.pt to data/lm") + torch.save(HLG.as_dict(), "data/lm/HLG_bpe.pt") + + +def main(): + phone_based_HLG() + bpe_based_HLG() + + if __name__ == "__main__": main() diff --git a/egs/librispeech/ASR/local/prepare_lang.py b/egs/librispeech/ASR/local/prepare_lang.py index f9864bd02..9945a5006 100755 --- a/egs/librispeech/ASR/local/prepare_lang.py +++ b/egs/librispeech/ASR/local/prepare_lang.py @@ -48,7 +48,7 @@ def read_lexicon(filename: str) -> Lexicon: """ ans = [] - with open(filename, "r", encoding="latin-1") as f: + with open(filename, "r", encoding="utf-8") as f: whitespace = re.compile("[ \t]+") for line in f: a = whitespace.split(line.strip(" \t\r\n")) @@ -80,7 +80,7 @@ def write_lexicon(filename: str, lexicon: Lexicon) -> None: lexicon: It can be the return value of :func:`read_lexicon`. """ - with open(filename, "w") as f: + with open(filename, "w", encoding="utf-8") as f: for word, prons in lexicon: f.write(f"{word} {' '.join(prons)}\n") @@ -100,7 +100,7 @@ def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: Returns: Return None. """ - with open(filename, "w") as f: + with open(filename, "w", encoding="utf-8") as f: for sym, i in sym2id.items(): f.write(f"{sym} {i}\n") @@ -151,7 +151,8 @@ def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: Return a tuple with two elements: - The output lexicon with disambiguation symbols - - The ID of the max disambiguation symbols + - The ID of the max disambiguation symbols that appears + in the lexicon """ # (1) Work out the count of each phone-sequence in the diff --git a/egs/librispeech/ASR/local/prepare_lang_bpe.py b/egs/librispeech/ASR/local/prepare_lang_bpe.py new file mode 100755 index 000000000..f70279cf4 --- /dev/null +++ b/egs/librispeech/ASR/local/prepare_lang_bpe.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +""" +This script takes as inputs the following files: + - data/lang/bpe/bpe.model, + - data/lang/bpe/tokens.txt (will remove it), + - data/lang/bpe/words.txt + +and generates the following files in the directory data/lang/bpe: + + - lexicon.txt + - lexicon_disambig.txt + - L.pt + - L_disambig.pt + - phones.txt +""" + +from pathlib import Path +from typing import Dict, List + +import k2 +import sentencepiece as spm +import torch +from prepare_lang import ( + Lexicon, + add_disambig_symbols, + add_self_loops, + write_lexicon, +) + + +def lexicon_to_fst_no_sil( + lexicon: Lexicon, + token2id: Dict[str, int], + word2id: Dict[str, int], + need_self_loops: bool = False, +) -> k2.Fsa: + """Convert a lexicon to an FST (in k2 format). + + Args: + lexicon: + The input lexicon. See also :func:`read_lexicon` + token2id: + A dict mapping tokens to IDs. + word2id: + A dict mapping words to IDs. + need_self_loops: + If True, add self-loop to states with non-epsilon output symbols + on at least one arc out of the state. + Returns: + Return an instance of `k2.Fsa` representing the given lexicon. + """ + loop_state = 0 # words enter and leave from here + next_state = 1 # the next un-allocated state, will be incremented as we go. + + arcs = [] + + assert token2id[""] == 0 + assert word2id[""] == 0 + + eps = 0 + + for word, prons in lexicon: + assert len(prons) > 0, f"{word} has no pronunciations" + cur_state = loop_state + + word = word2id[word] + prons = [token2id[i] for i in prons] + + for i in range(len(prons) - 1): + if i == 0: + arcs.append([cur_state, next_state, prons[i], word, 0]) + else: + arcs.append([cur_state, next_state, prons[i], eps, 0]) + + cur_state = next_state + next_state += 1 + + # now for the last phone of this word + i = len(prons) - 1 + w = word if i == 0 else eps + arcs.append([cur_state, loop_state, prons[i], w, 0]) + + if need_self_loops: + disambig_phone = token2id["#0"] + disambig_word = word2id["#0"] + arcs = add_self_loops( + arcs, + disambig_phone=disambig_phone, + disambig_word=disambig_word, + ) + + final_state = next_state + arcs.append([loop_state, final_state, -1, -1, 0]) + arcs.append([final_state]) + + arcs = sorted(arcs, key=lambda arc: arc[0]) + arcs = [[str(i) for i in arc] for arc in arcs] + arcs = [" ".join(arc) for arc in arcs] + arcs = "\n".join(arcs) + + fsa = k2.Fsa.from_str(arcs, acceptor=False) + return fsa + + +def generate_lexicon(model_file: str, words: List[str]) -> Lexicon: + """Generate a lexicon from a BPE model. + + Args: + model_file: + Path to a sentencepiece model. + words: + A list of strings representing words. + Returns: + Return a dict whose keys are words and values are the corresponding + word pieces. + """ + sp = spm.SentencePieceProcessor() + sp.load(str(model_file)) + + words_pieces: List[List[str]] = sp.encode(words, out_type=str) + + lexicon = [] + for word, pieces in zip(words, words_pieces): + lexicon.append((word, pieces)) + + lexicon.append(("", [""])) + return lexicon + + +def main(): + lang_dir = Path("data/lang/bpe") + model_file = lang_dir / "bpe.model" + + word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") + + words = word_sym_table.symbols + + excluded = ["", "!SIL", "", "", "#0", "", ""] + for w in excluded: + if w in words: + words.remove(w) + + lexicon = generate_lexicon(model_file, words) + + # TODO(fangjun): Remove tokens.txt and generate it from the model directly. + # + # We are using it since the IDs we are using in tokens.txt is + # different from the one contained in the model + token_sym_table = k2.SymbolTable.from_file(lang_dir / "tokens.txt") + + lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + + for i in range(max_disambig + 1): + disambig = f"#{i}" + assert disambig not in token_sym_table + token_sym_table.add(f"#{i}") + + word_sym_table.add("#0") + word_sym_table.add("") + word_sym_table.add("") + + token_sym_table.to_file(lang_dir / "phones.txt") + + write_lexicon(lang_dir / "lexicon.txt", lexicon) + write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) + + L = lexicon_to_fst_no_sil( + lexicon, + token2id=token_sym_table, + word2id=word_sym_table, + ) + + L_disambig = lexicon_to_fst_no_sil( + lexicon_disambig, + token2id=token_sym_table, + word2id=word_sym_table, + need_self_loops=True, + ) + torch.save(L.as_dict(), lang_dir / "L.pt") + torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") + + if False: + # Just for debugging, will remove it + L.labels_sym = k2.SymbolTable.from_file(lang_dir / "phones.txt") + L.aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") + L_disambig.labels_sym = L.labels_sym + L_disambig.aux_labels_sym = L.aux_labels_sym + L.draw(lang_dir / "L.svg", title="L") + L_disambig.draw(lang_dir / "L_disambig.svg", title="L_disambig") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 87786c5c8..b73d0e71f 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -89,7 +89,17 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then fi if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - echo "Stage 6: Prepare G" + echo "State 6: Prepare BPE based lang" + mkdir -p data/lang/bpe + cp data/lang/words.txt data/lang/bpe/ + + if [ ! -f data/lang/bpe/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py + fi +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + echo "Stage 7: Prepare G" # We assume you have install kaldilm, if not, please install # it using: pip install kaldilm @@ -112,9 +122,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then fi fi -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - echo "Stage 7: Compile HLG" - if [ ! -f data/lm/HLG.pt ]; then - python3 ./local/compile_hlg.py - fi +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + echo "Stage 8: Compile HLG" + python3 ./local/compile_hlg.py fi diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 000000000..a54edf118 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +kaldilm +kaldialign +sentencepiece>=0.1.96