From f6091b10c09ef7c32c94f1d426758e698f6db056 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 2 Aug 2021 23:48:26 +0800 Subject: [PATCH] Refactor transformer.py --- .../ASR/conformer_ctc/conformer.py | 24 +- egs/librispeech/ASR/conformer_ctc/decode.py | 12 +- .../ASR/conformer_ctc/subsampling.py | 144 ++++ .../ASR/conformer_ctc/test_subsampling.py | 33 + .../ASR/conformer_ctc/test_transformer.py | 36 + egs/librispeech/ASR/conformer_ctc/train.py | 23 +- .../ASR/conformer_ctc/transformer.py | 761 +++++++++--------- .../ASR/local/compute_fbank_librispeech.py | 10 +- .../ASR/local/compute_fbank_musan.py | 7 + icefall/decode.py | 39 +- 10 files changed, 689 insertions(+), 400 deletions(-) create mode 100644 egs/librispeech/ASR/conformer_ctc/subsampling.py create mode 100755 egs/librispeech/ASR/conformer_ctc/test_subsampling.py create mode 100644 egs/librispeech/ASR/conformer_ctc/test_transformer.py diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py index 1e82eff2f..d3952d3b1 100644 --- a/egs/librispeech/ASR/conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc/conformer.py @@ -89,15 +89,21 @@ class Conformer(Transformer): ) -> Tuple[Tensor, Optional[Tensor]]: """ Args: - x: Tensor of dimension (batch_size, num_features, input_length). - supervisions : Supervison in lhotse format, i.e., batch['supervisions'] + x: + The model input. Its shape is [N, T, C]. + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute encoder padding mask, which is used as memory key padding + mask for the decoder. Returns: Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). 
Tensor: Mask tensor of dimension (batch_size, input_length) """ - x = x.permute(0, 2, 1) # (B, F, T) -> (B, T, F) - x = self.encoder_embed(x) x, pos_emb = self.encoder_pos(x) x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) @@ -796,8 +802,7 @@ class RelPositionMultiheadAttention(nn.Module): bsz, num_heads, tgt_len, src_len ) attn_output_weights = attn_output_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), - float("-inf"), + key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"), ) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len @@ -867,12 +872,7 @@ class ConvolutionModule(nn.Module): ) self.norm = nn.BatchNorm1d(channels) self.pointwise_conv2 = nn.Conv1d( - channels, - channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, + channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, ) self.activation = Swish() diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 3a8db1b81..9ebb76fa1 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -147,15 +147,10 @@ def decode_one_batch( feature = feature.to(device) # at entry, feature is [N, T, C] - feature = feature.permute(0, 2, 1) # now feature is [N, C, T] - supervisions = batch["supervisions"] nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) - # nnet_output is [N, C, T] - - nnet_output = nnet_output.permute(0, 2, 1) - # now nnet_output is [N, T, C] + # nnet_output is [N, T, C] supervision_segments = torch.stack( ( @@ -227,6 +222,8 @@ def decode_one_batch( model=model, memory=memory, memory_key_padding_mask=memory_key_padding_mask, + sos_id=lexicon.sos_id, + eos_id=lexicon.eos_id, ) else: assert False, f"Unsupported decoding method: {params.method}" @@ -468,5 +465,8 @@ def main(): logging.info("Done!") +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + if __name__ == "__main__": main() diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py new file mode 100644 index 000000000..5c3e1222e --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -0,0 +1,144 @@ +import torch +import torch.nn as nn + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Convert an input of shape [N, T, idim] to an output + with shape [N, T', odim], where + T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__(self, idim: int, odim: int) -> None: + """ + Args: + idim: + Input dim. The input shape is [N, T, idim]. + Caution: It requires: T >=7, idim >=7 + odim: + Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim] + """ + assert idim >= 7 + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=1, out_channels=odim, kernel_size=3, stride=2 + ), + nn.ReLU(), + nn.Conv2d( + in_channels=odim, out_channels=odim, kernel_size=3, stride=2 + ), + nn.ReLU(), + ) + self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is [N, T, idim]. 
+ + Returns: + Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim] + """ + # On entry, x is [N, T, idim] + x = x.unsqueeze(1) # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W] + x = self.conv(x) + # Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2] + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + # Now x is of shape [N, ((T-1)//2 - 1))//2, odim] + return x + + +class VggSubsampling(nn.Module): + """Trying to follow the setup described in the following paper: + https://arxiv.org/pdf/1910.09799.pdf + + This paper is not 100% explicit so I am guessing to some extent, + and trying to compare with other VGG implementations. + + Convert an input of shape [N, T, idim] to an output + with shape [N, T', odim], where + T' = ((T-1)//2 - 1)//2, which approximates T' = T//4 + """ + + def __init__(self, idim: int, odim: int) -> None: + """Construct a VggSubsampling object. + + This uses 2 VGG blocks with 2 Conv2d layers each, + subsampling its input by a factor of 4 in the time dimensions. + + Args: + idim: + Input dim. The input shape is [N, T, idim]. + Caution: It requires: T >=7, idim >=7 + odim: + Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim] + """ + super().__init__() + + cur_channels = 1 + layers = [] + block_dims = [32, 64] + + # The decision to use padding=1 for the 1st convolution, then padding=0 + # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by + # a back-compatibility concern so that the number of frames at the + # output would be equal to: + # (((T-1)//2)-1)//2. + # We can consider changing this by using padding=1 on the + # 2nd convolution, so the num-frames at the output would be T//4. + for block_dim in block_dims: + layers.append( + torch.nn.Conv2d( + in_channels=cur_channels, + out_channels=block_dim, + kernel_size=3, + padding=1, + stride=1, + ) + ) + layers.append(torch.nn.ReLU()) + layers.append( + torch.nn.Conv2d( + in_channels=block_dim, + out_channels=block_dim, + kernel_size=3, + padding=0, + stride=1, + ) + ) + layers.append( + torch.nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True + ) + ) + cur_channels = block_dim + + self.layers = nn.Sequential(*layers) + + self.out = nn.Linear( + block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is [N, T, idim]. 
+ + Returns: + Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim] + """ + x = x.unsqueeze(1) + x = self.layers(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + return x diff --git a/egs/librispeech/ASR/conformer_ctc/test_subsampling.py b/egs/librispeech/ASR/conformer_ctc/test_subsampling.py new file mode 100755 index 000000000..937845d77 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/test_subsampling.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 + +from subsampling import Conv2dSubsampling +from subsampling import VggSubsampling +import torch + + +def test_conv2d_subsampling(): + N = 3 + odim = 2 + + for T in range(7, 19): + for idim in range(7, 20): + model = Conv2dSubsampling(idim=idim, odim=odim) + x = torch.empty(N, T, idim) + y = model(x) + assert y.shape[0] == N + assert y.shape[1] == ((T - 1) // 2 - 1) // 2 + assert y.shape[2] == odim + + +def test_vgg_subsampling(): + N = 3 + odim = 2 + + for T in range(7, 19): + for idim in range(7, 20): + model = VggSubsampling(idim=idim, odim=odim) + x = torch.empty(N, T, idim) + y = model(x) + assert y.shape[0] == N + assert y.shape[1] == ((T - 1) // 2 - 1) // 2 + assert y.shape[2] == odim diff --git a/egs/librispeech/ASR/conformer_ctc/test_transformer.py b/egs/librispeech/ASR/conformer_ctc/test_transformer.py new file mode 100644 index 000000000..a6569e8d7 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/test_transformer.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 + +import torch +from transformer import Transformer, encoder_padding_mask + + +def test_encoder_padding_mask(): + supervisions = { + "sequence_idx": torch.tensor([0, 1, 2]), + "start_frame": torch.tensor([0, 0, 0]), + "num_frames": torch.tensor([18, 7, 13]), + } + + max_len = ((18 - 1) // 2 - 1) // 2 + mask = encoder_padding_mask(max_len, supervisions) + expected_mask = torch.tensor( + [ + [False, False, False], # ((18 - 1)//2 - 1)//2 = 3, + [False, True, True], # ((7 - 1)//2 - 1)//2 = 1, + [False, False, True], # ((13 - 1)//2 - 1)//2 = 2, + ] + ) + assert torch.all(torch.eq(mask, expected_mask)) + + +def test_transformer(): + num_features = 40 + num_classes = 87 + model = Transformer(num_features=num_features, num_classes=num_classes) + + N = 31 + + for T in range(7, 30): + x = torch.rand(N, T, num_features) + y, _, _ = model(x) + assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes) diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index d411a3783..552db81ec 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -275,15 +275,13 @@ def compute_loss( device = graph_compiler.device feature = batch["inputs"] # at entry, feature is [N, T, C] - feature = feature.permute(0, 2, 1) # now feature is [N, C, T] assert feature.ndim == 3 feature = feature.to(device) supervisions = batch["supervisions"] with torch.set_grad_enabled(is_training): nnet_output, encoder_memory, memory_mask = model(feature, supervisions) - # nnet_output is [N, C, T] - nnet_output = nnet_output.permute(0, 2, 1) # [N, C, T] -> [N, T, C] + # nnet_output is [N, T, C] # NOTE: We need `encode_supervisions` to sort sequences with # different duration in decreasing order, required by @@ -536,6 +534,22 @@ def train_one_epoch( f" best valid loss: {params.best_valid_loss:.4f} " f"best valid epoch: {params.best_valid_epoch}" ) + if tb_writer is not None: + tb_writer.add_scalar( + "train/valid_ctc_loss", + params.valid_ctc_loss, + params.batch_idx_train, + ) + 
+                tb_writer.add_scalar(
+                    "train/valid_att_loss",
+                    params.valid_att_loss,
+                    params.batch_idx_train,
+                )
+                tb_writer.add_scalar(
+                    "train/valid_loss",
+                    params.valid_loss,
+                    params.batch_idx_train,
+                )
 
     params.train_loss = tot_loss / tot_frames
 
@@ -675,5 +689,8 @@ def main():
     run(rank=0, world_size=1, args=args)
 
 
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
 if __name__ == "__main__":
     main()
diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py
index 06027cf64..b2123b8fc 100644
--- a/egs/librispeech/ASR/conformer_ctc/transformer.py
+++ b/egs/librispeech/ASR/conformer_ctc/transformer.py
@@ -1,6 +1,4 @@
-#!/usr/bin/env python3
-
-# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
+# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
 # Apache 2.0
 
 import math
@@ -8,30 +6,16 @@ from typing import Dict, List, Optional, Tuple
 
 import k2
 import torch
-from torch import Tensor, nn
+import torch.nn as nn
+from subsampling import Conv2dSubsampling, VggSubsampling
 
 from icefall.utils import get_texts
 
 # Note: TorchScript requires Dict/List/etc. to be fully typed.
-Supervisions = Dict[str, Tensor]
+Supervisions = Dict[str, torch.Tensor]
 
 
 class Transformer(nn.Module):
-    """
-    Args:
-        num_features (int): Number of input features
-        num_classes (int): Number of output classes
-        subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
-        d_model (int): attention dimension
-        nhead (int): number of head
-        dim_feedforward (int): feedforward dimention
-        num_encoder_layers (int): number of encoder layers
-        num_decoder_layers (int): number of decoder layers
-        dropout (float): dropout rate
-        normalize_before (bool): whether to use layer_norm before the first block.
-        vgg_frontend (bool): whether to use vgg frontend.
-    """
-
     def __init__(
         self,
         num_features: int,
@@ -48,6 +32,36 @@ class Transformer(nn.Module):
         mmi_loss: bool = True,
         use_feat_batchnorm: bool = False,
     ) -> None:
+        """
+        Args:
+          num_features:
+            The input dimension of the model.
+          num_classes:
+            The output dimension of the model.
+          subsampling_factor:
+            Number of output frames is num_in_frames // subsampling_factor.
+            Currently, subsampling_factor MUST be 4.
+          d_model:
+            Attention dimension.
+          nhead:
+            Number of heads in multi-head attention.
+            Must satisfy d_model % nhead == 0.
+          dim_feedforward:
+            The output dimension of the feedforward layers in encoder/decoder.
+          num_encoder_layers:
+            Number of encoder layers.
+          num_decoder_layers:
+            Number of decoder layers.
+          dropout:
+            Dropout in encoder/decoder.
+          normalize_before:
+            If True, use pre-layer norm; False to use post-layer norm.
+          vgg_frontend:
+            True to use vgg style frontend for subsampling.
+          mmi_loss:
+          use_feat_batchnorm:
+            True to use batchnorm for the input layer.
+        """
         super().__init__()
         self.use_feat_batchnorm = use_feat_batchnorm
         if use_feat_batchnorm:
@@ -59,18 +73,23 @@ class Transformer(nn.Module):
         if subsampling_factor != 4:
             raise NotImplementedError("Support only 'subsampling_factor=4'.")
 
-        self.encoder_embed = (
-            VggSubsampling(num_features, d_model)
-            if vgg_frontend
-            else Conv2dSubsampling(num_features, d_model)
-        )
+        # self.encoder_embed converts the input of shape [N, T, num_features]
+        # to the shape [N, T//subsampling_factor, d_model].
+ # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_classes -> d_model + if vgg_frontend: + self.encoder_embed = VggSubsampling(num_features, d_model) + else: + self.encoder_embed = Conv2dSubsampling(num_features, d_model) + self.encoder_pos = PositionalEncoding(d_model, dropout) encoder_layer = TransformerEncoderLayer( - d_model, - nhead, - dim_feedforward, - dropout, + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, normalize_before=normalize_before, ) @@ -80,9 +99,12 @@ class Transformer(nn.Module): encoder_norm = None self.encoder = nn.TransformerEncoder( - encoder_layer, num_encoder_layers, encoder_norm + encoder_layer=encoder_layer, + num_layers=num_encoder_layers, + norm=encoder_norm, ) + # TODO(fangjun): remove dropout self.encoder_output_layer = nn.Sequential( nn.Dropout(p=dropout), nn.Linear(d_model, num_classes) ) @@ -97,14 +119,16 @@ class Transformer(nn.Module): self.num_classes ) # bpe model already has sos/eos symbol - self.decoder_embed = nn.Embedding(self.decoder_num_class, d_model) + self.decoder_embed = nn.Embedding( + num_embeddings=self.decoder_num_class, embedding_dim=d_model + ) self.decoder_pos = PositionalEncoding(d_model, dropout) decoder_layer = TransformerDecoderLayer( - d_model, - nhead, - dim_feedforward, - dropout, + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, normalize_before=normalize_before, ) @@ -114,7 +138,9 @@ class Transformer(nn.Module): decoder_norm = None self.decoder = nn.TransformerDecoder( - decoder_layer, num_decoder_layers, decoder_norm + decoder_layer=decoder_layer, + num_layers=num_decoder_layers, + norm=decoder_norm, ) self.decoder_output_layer = torch.nn.Linear( @@ -126,93 +152,145 @@ class Transformer(nn.Module): self.decoder_criterion = None def forward( - self, x: Tensor, supervision: Optional[Supervisions] = None - ) -> Tuple[Tensor, Tensor, Optional[Tensor]]: + self, x: torch.Tensor, supervision: Optional[Supervisions] = None + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Args: - x: Tensor of dimension (batch_size, num_features, input_length). - supervision: Supervison in lhotse format, get from batch['supervisions'] + x: + The input tensor. Its shape is [N, T, C]. + supervision: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) Returns: - Tensor: After log-softmax tensor of dimension (batch_size, number_of_classes, input_length). - Tensor: Before linear layer tensor of dimension (input_length, batch_size, d_model). - Optional[Tensor]: Mask tensor of dimension (batch_size, input_length) or None. - + Return a tuple containing 3 tensors: + - CTC output for ctc decoding. Its shape is [N, T, C] + - Encoder output with shape [T, N, C]. It can be used as key and + value for the decoder. + - Encoder output padding mask. It can be used as + memory_key_padding_mask for the decoder. Its shape is [N, T]. + It is None if `supervision` is None. 
""" if self.use_feat_batchnorm: + x = x.permute(0, 2, 1) # [N, T, C] -> [N, C, T] x = self.feat_batchnorm(x) - encoder_memory, memory_mask = self.encode(x, supervision) + x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C] + encoder_memory, memory_key_padding_mask = self.encode(x, supervision) x = self.encoder_output(encoder_memory) - return x, encoder_memory, memory_mask + return x, encoder_memory, memory_key_padding_mask def encode( - self, x: Tensor, supervisions: Optional[Supervisions] = None - ) -> Tuple[Tensor, Optional[Tensor]]: + self, x: torch.Tensor, supervisions: Optional[Supervisions] = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ Args: - x: Tensor of dimension (batch_size, num_features, input_length). - supervisions : Supervison in lhotse format, i.e., batch['supervisions'] - + x: + The model input. Its shape is [N, T, C]. + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute encoder padding mask, which is used as memory key padding + mask for the decoder. Returns: - Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). - Optional[Tensor]: Mask tensor of dimension (batch_size, input_length) or None. + Return a tuple with two tensors: + - The encoder output, with shape [T, N, C] + - encoder padding mask, with shape [N, T]. + The mask is None if `supervisions` is None. + It is used as memory key padding mask in the decoder. """ - x = x.permute(0, 2, 1) # (B, F, T) -> (B, T, F) - x = self.encoder_embed(x) x = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) mask = encoder_padding_mask(x.size(0), supervisions) - mask = mask.to(x.device) if mask != None else None - x = self.encoder(x, src_key_padding_mask=mask) # (T, B, F) + mask = mask.to(x.device) if mask is not None else None + x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C) return x, mask - def encoder_output(self, x: Tensor) -> Tensor: + def encoder_output(self, x: torch.Tensor) -> torch.Tensor: """ Args: - x: Tensor of dimension (input_length, batch_size, d_model). + x: + The output tensor from the transformer encoder. + Its shape is [T, N, C] Returns: - Tensor: After log-softmax tensor of dimension (batch_size, number_of_classes, input_length). + Return a tensor that can be used for CTC decoding. 
+ Its shape is [N, T, C] """ - x = self.encoder_output_layer(x).permute( - 1, 2, 0 - ) # (T, B, F) ->(B, F, T) - x = nn.functional.log_softmax(x, dim=1) # (B, F, T) + x = self.encoder_output_layer(x) + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + x = nn.functional.log_softmax(x, dim=-1) # (N, T, C) return x def decoder_forward( self, - x: Tensor, - encoder_mask: Tensor, - supervision: Supervisions = None, - graph_compiler: object = None, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, + supervision: Optional[Supervisions] = None, + L_inv: Optional[k2.Fsa] = None, + word_table: Optional[k2.SymbolTable] = None, + oov_str: Optional[str] = None, token_ids: List[List[int]] = None, sos_id: Optional[int] = None, eos_id: Optional[int] = None, - ) -> Tensor: + ) -> torch.Tensor: """ + Note: + If phone based lexicon is used, the following arguments are required: + + - supervision + - L_inv + - word_table + - oov_str + + If BPE based lexicon is used, the following arguments are required: + + - token_ids + - sos_id + - eos_id + Args: - x: Tensor of dimension (input_length, batch_size, d_model). - encoder_mask: Mask tensor of dimension (batch_size, input_length) - supervision: Supervison in lhotse format, get from batch['supervisions'] - graph_compiler: use graph_compiler.L_inv (Its labels are words, while its aux_labels are phones) - , graph_compiler.words and graph_compiler.oov - token_ids: A list of lists. Each list contains word piece IDs for an utterance. - sos_id: sos token id - eos_id: eos token id + memory: + It's the output of the encoder with shape [T, N, C] + memory_key_padding_mask: + The padding mask from the encoder. + supervision: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) + L_inv: + It is an FSA with labels being word IDs and aux_labels being + token IDs (e.g., phone IDs or word piece IDs). + word_table: + Word table providing mapping between words and IDs. + oov_str: + The OOV word, e.g., '' + token_ids: + A list-of-list IDs. Each sublist contains IDs for an utterance. + The IDs can be either phone IDs or word piece IDs. + sos_id: + sos token id + eos_id: + eos token id Returns: - Tensor: Decoder loss. + A scalar, the **sum** of label smoothing loss over utterances + in the batch without any normalization. 
""" - if supervision is not None and graph_compiler is not None: + if supervision is not None and word_table is not None: batch_text = get_normal_transcripts( - supervision, graph_compiler.lexicon.words, graph_compiler.oov + supervision, word_table, oov_str ) ys_in_pad, ys_out_pad = add_sos_eos( batch_text, - graph_compiler.L_inv, + L_inv, sos_id, eos_id, ) @@ -227,31 +305,31 @@ class Transformer(nn.Module): ] ys_in_pad = pad_list(ys_in, eos_id) ys_out_pad = pad_list(ys_out, -1) - else: raise ValueError("Invalid input for decoder self attention") - ys_in_pad = ys_in_pad.to(x.device) - ys_out_pad = ys_out_pad.to(x.device) + device = memory.device + ys_in_pad = ys_in_pad.to(device) + ys_out_pad = ys_out_pad.to(device) tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - x.device + device ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad) - tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) + tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) tgt = self.decoder_pos(tgt) - tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C) pred_pad = self.decoder( tgt=tgt, - memory=x, + memory=memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=encoder_mask, - ) # (T, B, F) - pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) - pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) + memory_key_padding_mask=memory_key_padding_mask, + ) # (T, N, C) + pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad) @@ -259,23 +337,31 @@ class Transformer(nn.Module): def decoder_nll( self, - x: Tensor, - encoder_mask: Tensor, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, token_ids: List[List[int]], sos_id: int, eos_id: int, - ) -> Tensor: + ) -> torch.Tensor: """ Args: - x: encoder-output, Tensor of dimension (input_length, batch_size, d_model). - encoder_mask: Mask tensor of dimension (batch_size, input_length) - token_ids: n-best list extracted from lattice before rescore - + memory: + It's the output of the encoder with shape [T, N, C] + memory_key_padding_mask: + The padding mask from the encoder. + token_ids: + A list-of-list IDs (e.g., word piece IDs). + Each sublist represents an utterance. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. Returns: - Tensor: negative log-likelihood. + A 2-D tensor of shape (len(token_ids), max_token_length) + representing the cross entropy loss (i.e., negative log-likelihood). """ - # The common part between this fuction and decoder_forward could be - # extracted as a seperated function. + # The common part between this function and decoder_forward could be + # extracted as a separate function. 
if token_ids is not None: _sos = torch.tensor([sos_id]) _eos = torch.tensor([eos_id]) @@ -290,11 +376,12 @@ class Transformer(nn.Module): else: raise ValueError("Invalid input for decoder self attention") - ys_in_pad = ys_in_pad.to(x.device, dtype=torch.int64) - ys_out_pad = ys_out_pad.to(x.device, dtype=torch.int64) + device = memory.device + ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) + ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - x.device + device ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad) @@ -304,10 +391,10 @@ class Transformer(nn.Module): tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) pred_pad = self.decoder( tgt=tgt, - memory=x, + memory=memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=encoder_mask, + memory_key_padding_mask=memory_key_padding_mask, ) # (T, B, F) pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) @@ -326,16 +413,24 @@ class Transformer(nn.Module): class TransformerEncoderLayer(nn.Module): """ - Modified from torch.nn.TransformerEncoderLayer. Add support of normalize_before, + Modified from torch.nn.TransformerEncoderLayer. + Add support of normalize_before, i.e., use layer_norm before the first block. Args: - d_model: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - dim_feedforward: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - activation: the activation function of intermediate layer, relu or gelu (default=relu). - normalize_before: whether to use layer_norm before the first block. + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). + normalize_before: + whether to use layer_norm before the first block. Examples:: >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) @@ -375,23 +470,24 @@ class TransformerEncoderLayer(nn.Module): def forward( self, - src: Tensor, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: + src: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: """ Pass the input through the encoder layer. Args: src: the sequence to the encoder layer (required). src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). + src_key_padding_mask: the mask for the src keys per batch (optional) Shape: src: (S, N, E). src_mask: (S, S). src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number """ residual = src if self.normalize_before: @@ -419,15 +515,22 @@ class TransformerEncoderLayer(nn.Module): class TransformerDecoderLayer(nn.Module): """ - Modified from torch.nn.TransformerDecoderLayer. Add support of normalize_before, + Modified from torch.nn.TransformerDecoderLayer. 
+ Add support of normalize_before, i.e., use layer_norm before the first block. Args: - d_model: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - dim_feedforward: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - activation: the activation function of intermediate layer, relu or gelu (default=relu). + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). Examples:: >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) @@ -471,22 +574,28 @@ class TransformerDecoderLayer(nn.Module): def forward( self, - tgt: Tensor, - memory: Tensor, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: """Pass the inputs (and mask) through the decoder layer. Args: - tgt: the sequence to the decoder layer (required). - memory: the sequence from the last layer of the encoder (required). - tgt_mask: the mask for the tgt sequence (optional). - memory_mask: the mask for the memory sequence (optional). - tgt_key_padding_mask: the mask for the tgt keys per batch (optional). - memory_key_padding_mask: the mask for the memory keys per batch (optional). + tgt: + the sequence to the decoder layer (required). + memory: + the sequence from the last layer of the encoder (required). + tgt_mask: + the mask for the tgt sequence (optional). + memory_mask: + the mask for the memory sequence (optional). + tgt_key_padding_mask: + the mask for the tgt keys per batch (optional). + memory_key_padding_mask: + the mask for the memory keys per batch (optional). Shape: tgt: (T, N, E). @@ -495,7 +604,8 @@ class TransformerDecoderLayer(nn.Module): memory_mask: (T, S). tgt_key_padding_mask: (N, T). memory_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number """ residual = tgt if self.normalize_before: @@ -546,164 +656,55 @@ def _get_activation_fn(activation: str): ) -class Conv2dSubsampling(nn.Module): - """Convolutional 2D subsampling (to 1/4 length). - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py - - Args: - idim: Input dimension. - odim: Output dimension. 
- - """ - - def __init__(self, idim: int, odim: int) -> None: - """Construct a Conv2dSubsampling object.""" - super(Conv2dSubsampling, self).__init__() - self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), - nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), - nn.ReLU(), - ) - self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) - - def forward(self, x: Tensor) -> Tensor: - """Subsample x. - - Args: - x: Input tensor of dimension (batch_size, input_length, num_features). (#batch, time, idim). - - Returns: - torch.Tensor: Subsampled tensor of dimension (batch_size, input_length, d_model). - where time' = time // 4. - - """ - x = x.unsqueeze(1) # (b, c, t, f) - x = self.conv(x) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - return x - - -class VggSubsampling(nn.Module): - """Trying to follow the setup described here https://arxiv.org/pdf/1910.09799.pdf - This paper is not 100% explicit so I am guessing to some extent, - and trying to compare with other VGG implementations. - - Args: - idim: Input dimension. - odim: Output dimension. - - """ - - def __init__(self, idim: int, odim: int) -> None: - """Construct a VggSubsampling object. This uses 2 VGG blocks with 2 - Conv2d layers each, subsampling its input by a factor of 4 in the - time dimensions. - - Args: - idim: Number of features at input, e.g. 40 or 80 for MFCC - (will be treated as the image height). - odim: Output dimension (number of features), e.g. 256 - """ - super(VggSubsampling, self).__init__() - - cur_channels = 1 - layers = [] - block_dims = [32, 64] - - # The decision to use padding=1 for the 1st convolution, then padding=0 - # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by - # a back-compatibility concern so that the number of frames at the - # output would be equal to: - # (((T-1)//2)-1)//2. - # We can consider changing this by using padding=1 on the 2nd convolution, - # so the num-frames at the output would be T//4. - for block_dim in block_dims: - layers.append( - torch.nn.Conv2d( - in_channels=cur_channels, - out_channels=block_dim, - kernel_size=3, - padding=1, - stride=1, - ) - ) - layers.append(torch.nn.ReLU()) - layers.append( - torch.nn.Conv2d( - in_channels=block_dim, - out_channels=block_dim, - kernel_size=3, - padding=0, - stride=1, - ) - ) - layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) - ) - cur_channels = block_dim - - self.layers = nn.Sequential(*layers) - - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) - - def forward(self, x: Tensor) -> Tensor: - """Subsample x. - - Args: - x: Input tensor of dimension (batch_size, input_length, num_features). (#batch, time, idim). - - Returns: - torch.Tensor: Subsampled tensor of dimension (batch_size, input_length', d_model). - where input_length' == (((input_length - 1) // 2) - 1) // 2 - - """ - x = x.unsqueeze(1) # (b, c, t, f) - x = self.layers(x) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - return x - - class PositionalEncoding(nn.Module): - """ - Positional encoding. + """This class implements the positional encoding + proposed in the following paper: - Args: - d_model: Embedding dimension. - dropout: Dropout rate. - max_len: Maximum input length. 
+    - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
+
+    PE(pos, 2i) = sin(pos / (10000^(2i/d_model)))
+    PE(pos, 2i+1) = cos(pos / (10000^(2i/d_model)))
+
+    Note::
+
+      1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model)))
+                               = exp(-1 * 2i / d_model * log(10000))
+                               = exp(2i * -(log(10000) / d_model))
     """
 
-    def __init__(
-        self, d_model: int, dropout: float = 0.1, max_len: int = 5000
-    ) -> None:
-        """Construct an PositionalEncoding object."""
-        super(PositionalEncoding, self).__init__()
+    def __init__(self, d_model: int, dropout: float = 0.1) -> None:
+        """
+        Args:
+          d_model:
+            Embedding dimension.
+          dropout:
+            Dropout probability to be applied to the output of this module.
+        """
+        super().__init__()
         self.d_model = d_model
         self.xscale = math.sqrt(self.d_model)
         self.dropout = nn.Dropout(p=dropout)
         self.pe = None
-        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
 
-    def extend_pe(self, x: Tensor) -> None:
-        """Reset the positional encodings."""
+    def extend_pe(self, x: torch.Tensor) -> None:
+        """Extend the time t in the positional encoding if required.
+
+        The shape of `self.pe` is [1, T1, d_model]. The shape of the input x
+        is [N, T, d_model]. If T > T1, then we change the shape of self.pe
+        to [1, T, d_model]. Otherwise, nothing is done.
+
+        Args:
+          x:
+            It is a tensor of shape [N, T, C].
+        Returns:
+          Return None.
+        """
         if self.pe is not None:
             if self.pe.size(1) >= x.size(1):
                 if self.pe.dtype != x.dtype or self.pe.device != x.device:
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
-        pe = torch.zeros(x.size(1), self.d_model)
+        pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)
         position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
         div_term = torch.exp(
             torch.arange(0, self.d_model, 2, dtype=torch.float32)
@@ -712,34 +713,44 @@ class PositionalEncoding(nn.Module):
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
         pe = pe.unsqueeze(0)
+        # Now pe is of shape [1, T, d_model], where T is x.size(1)
        self.pe = pe.to(device=x.device, dtype=x.dtype)
 
-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Add positional encoding.
 
         Args:
-            x: Input tensor of dimention (batch_size, input_length, d_model).
+          x:
+            Its shape is [N, T, C].
 
         Returns:
-            torch.Tensor: Encoded tensor of dimention (batch_size, input_length, d_model).
-
+          Return a tensor of shape [N, T, C].
         """
         self.extend_pe(x)
-        x = x * self.xscale + self.pe[:, : x.size(1)]
+        x = x * self.xscale + self.pe[:, : x.size(1), :]
         return self.dropout(x)
 
 
 class Noam(object):
     """
-    Implements Noam optimizer. Proposed in "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
-    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py
+    Implements Noam optimizer.
+ + Proposed in + "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf + + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa Args: - params (iterable): iterable of parameters to optimize or dicts defining parameter groups - model_size: attention dimension of the transformer model - factor: learning rate factor - warm_step: warmup steps + params: + iterable of parameters to optimize or dicts defining parameter groups + model_size: + attention dimension of the transformer model + factor: + learning rate factor + warm_step: + warmup steps """ def __init__( @@ -812,7 +823,8 @@ class LabelSmoothingLoss(nn.Module): """ Label-smoothing loss. KL-divergence between q_{smoothed ground truth prob.}(w) and p_{prob. computed by model}(w) is minimized. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py # noqa Args: size: the number of class @@ -841,19 +853,23 @@ class LabelSmoothingLoss(nn.Module): self.true_dist = None self.normalize_length = normalize_length - def forward(self, x: Tensor, target: Tensor) -> Tensor: + def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Compute loss between x and target. Args: - x: prediction of dimention (batch_size, input_length, number_of_classes). - target: target masked with self.padding_id of dimention (batch_size, input_length). + x: + prediction of dimension + (batch_size, input_length, number_of_classes). + target: + target masked with self.padding_id of + dimension (batch_size, input_length). Returns: - torch.Tensor: scalar float value + A scalar tensor containing the loss without normalization. """ assert x.size(2) == self.size - batch_size = x.size(0) + # batch_size = x.size(0) x = x.view(-1, self.size) target = target.view(-1) with torch.no_grad(): @@ -871,12 +887,23 @@ class LabelSmoothingLoss(nn.Module): def encoder_padding_mask( max_len: int, supervisions: Optional[Supervisions] = None -) -> Optional[Tensor]: - """Make mask tensor containing indices of padded part. +) -> Optional[torch.Tensor]: + """Make mask tensor containing indexes of padded part. + + TODO:: + This function **assumes** that the model uses + a subsampling factor of 4. We should remove that + assumption later. Args: - max_len: maximum length of input features - supervisions : Supervison in lhotse format, i.e., batch['supervisions'] + max_len: + Maximum length of input features. + CAUTION: It is the length after subsampling. + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) Returns: Tensor: Mask tensor of dimension (batch_size, input_length), True denote the masked indices. @@ -916,16 +943,23 @@ def encoder_padding_mask( return mask -def decoder_padding_mask(ys_pad: Tensor, ignore_id: int = -1) -> Tensor: - """Generate a length mask for input. The masked position are filled with bool(True), - Unmasked positions are filled with bool(False). +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: + """Generate a length mask for input. + + The masked position are filled with bool(True), + Unmasked positions are filled with bool(False). 
Args: - ys_pad: padded tensor of dimension (batch_size, input_length). - ignore_id: the ignored number (the padding number) in ys_pad + ys_pad: + padded tensor of dimension (batch_size, input_length). + ignore_id: + the ignored number (the padding number) in ys_pad Returns: - Tensor: a mask tensor of dimension (batch_size, input_length). + Tensor: + a bool tensor of the same shape as the input tensor. """ ys_mask = ys_pad == ignore_id return ys_mask @@ -934,13 +968,20 @@ def decoder_padding_mask(ys_pad: Tensor, ignore_id: int = -1) -> Tensor: def get_normal_transcripts( supervision: Supervisions, words: k2.SymbolTable, oov: str = "" ) -> List[List[int]]: - """Get normal transcripts (1 input recording has 1 transcript) from lhotse cut format. - Achieved by concatenate the transcripts corresponding to the same recording. + """Get normal transcripts (1 input recording has 1 transcript) + from lhotse cut format. + + Achieved by concatenating the transcripts corresponding to the + same recording. Args: - supervision : Supervison in lhotse format, i.e., batch['supervisions'] - words: The word symbol table. - oov: Out of vocabulary word. + supervision: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + words: + The word symbol table. + oov: + Out of vocabulary word. Returns: List[List[int]]: List of concatenated transcripts, length is batch_size @@ -960,15 +1001,15 @@ def get_normal_transcripts( return batch_text -def generate_square_subsequent_mask(sz: int) -> Tensor: - """Generate a square mask for the sequence. The masked positions are filled with float('-inf'). - Unmasked positions are filled with float(0.0). +def generate_square_subsequent_mask(sz: int) -> torch.Tensor: + """Generate a square mask for the sequence. The masked positions are + filled with float('-inf'). Unmasked positions are filled with float(0.0). Args: - sz: mask size + sz: mask size Returns: - Tensor: a square mask of dimension (sz, sz) + A square mask of dimension (sz, sz) """ mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) mask = ( @@ -981,39 +1022,49 @@ def generate_square_subsequent_mask(sz: int) -> Tensor: def add_sos_eos( ys: List[List[int]], - lexicon: k2.Fsa, + L_inv: k2.Fsa, sos_id: int, eos_id: int, ignore_id: int = -1, -) -> Tuple[Tensor, Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """Add and labels. Args: - ys: batch of unpadded target sequences - lexicon: Its labels are words, while its aux_labels are phones. - sos_id: index of - eos_id: index of - ignore_id: index of padding + ys: + Batch of unpadded target sequences (i.e., word IDs) + L_inv: + Its labels are words, while its aux_labels are tokens. + sos_id: + index of + eos_id: + index of + ignore_id: + value for padding Returns: - Tensor: Input of transformer decoder. Padded tensor of dimention (batch_size, max_length). - Tensor: Output of transformer decoder. padded tensor of dimention (batch_size, max_length). + Return a tuple containing two tensors: + - Input of transformer decoder. + Padded tensor of dimension (batch_size, max_length). + - Output of transformer decoder. + Padded tensor of dimension (batch_size, max_length). 
""" _sos = torch.tensor([sos_id]) _eos = torch.tensor([eos_id]) - ys = get_hierarchical_targets(ys, lexicon) + ys = get_hierarchical_targets(ys, L_inv) ys_in = [torch.cat([_sos, y], dim=0) for y in ys] ys_out = [torch.cat([y, _eos], dim=0) for y in ys] - return pad_list(ys_in, eos), pad_list(ys_out, ignore_id) + return pad_list(ys_in, eos_id), pad_list(ys_out, ignore_id) -def pad_list(ys: List[Tensor], pad_value: float) -> Tensor: +def pad_list(ys: List[torch.Tensor], pad_value: float) -> torch.Tensor: """Perform padding for the list of tensors. Args: - ys: List of tensors. len(ys) = batch_size. - pad_value: Value for padding. + ys: + List of tensors. len(ys) = batch_size. + pad_value: + Value for padding. Returns: Tensor: Padded tensor (batch_size, max_length, `*`). @@ -1039,25 +1090,25 @@ def pad_list(ys: List[Tensor], pad_value: float) -> Tensor: def get_hierarchical_targets( - ys: List[List[int]], lexicon: k2.Fsa -) -> List[Tensor]: - """Get hierarchical transcripts (i.e., phone level transcripts) from transcripts (i.e., word level transcripts). + ys: List[List[int]], L_inv: Optional[k2.Fsa] = None +) -> List[torch.Tensor]: + """Get hierarchical transcripts (i.e., phone level transcripts) from + transcripts (i.e., word level transcripts). Args: - ys: Word level transcripts. - lexicon: Its labels are words, while its aux_labels are phones. + ys: + Word level transcripts. Each sublist is a transcript of an utterance. + L_inv: + Its labels are words, while its aux_labels are tokens. Returns: - List[Tensor]: Phone level transcripts. - + List[torch.Tensor]: + Token level transcripts. """ - if lexicon is None: - return ys - else: - L_inv = lexicon + if L_inv is None: + return [torch.tensor(y) for y in ys] - n_batch = len(ys) device = L_inv.device transcripts = k2.create_fsa_vec( @@ -1081,19 +1132,3 @@ def get_hierarchical_targets( ys = [torch.tensor(y) for y in ys] return ys - - -def test_transformer(): - t = Transformer(40, 1281) - T = 200 - f = torch.rand(31, 40, T) - g, _, _ = t(f) - assert g.shape == (31, 1281, (((T - 1) // 2) - 1) // 2) - - -def main(): - test_transformer() - - -if __name__ == "__main__": - main() diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py index 0c07aaa1a..d81096070 100755 --- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py +++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py @@ -11,11 +11,18 @@ import logging import os from pathlib import Path +import torch from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor +# Torch's multithreaded behavior needs to be disabled or it wastes a lot of CPU and +# slow things down. Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). 
+torch.set_num_threads(1) +torch.set_num_interop_threads(1) + def compute_fbank_librispeech(): src_dir = Path("data/manifests") @@ -46,8 +53,7 @@ def compute_fbank_librispeech(): continue logging.info(f"Processing {partition}") cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], + recordings=m["recordings"], supervisions=m["supervisions"], ) if "train" in partition: cut_set = ( diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py index 6a46e6978..0fc515d8c 100755 --- a/egs/librispeech/ASR/local/compute_fbank_musan.py +++ b/egs/librispeech/ASR/local/compute_fbank_musan.py @@ -11,11 +11,18 @@ import logging import os from pathlib import Path +import torch from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer, combine from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor +# Torch's multithreaded behavior needs to be disabled or it wastes a lot of CPU and +# slow things down. Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + def compute_fbank_musan(): src_dir = Path("data/manifests") diff --git a/icefall/decode.py b/icefall/decode.py index ed08405fa..0e9baf2e4 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -555,24 +555,31 @@ def rescore_with_attention_decoder( model: nn.Module, memory: torch.Tensor, memory_key_padding_mask: torch.Tensor, + sos_id: int, + eos_id: int, ) -> Dict[str, k2.Fsa]: """This function extracts n paths from the given lattice and uses an attention decoder to rescore them. The path with the highest score is used as the decoding output. - lattice: - An FsaVec. It can be the return value of :func:`get_lattice`. - num_paths: - Number of paths to extract from the given lattice for rescoring. - model: - A transformer model. See the class "Transformer" in - conformer_ctc/transformer.py for its interface. - memory: - The encoder memory of the given model. It is the output of - the last torch.nn.TransformerEncoder layer in the given model. - Its shape is `[T, N, C]`. - memory_key_padding_mask: - The padding mask for memory with shape [N, T]. + Args: + lattice: + An FsaVec. It can be the return value of :func:`get_lattice`. + num_paths: + Number of paths to extract from the given lattice for rescoring. + model: + A transformer model. See the class "Transformer" in + conformer_ctc/transformer.py for its interface. + memory: + The encoder memory of the given model. It is the output of + the last torch.nn.TransformerEncoder layer in the given model. + Its shape is `[T, N, C]`. + memory_key_padding_mask: + The padding mask for memory with shape [N, T]. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. Returns: A dict of FsaVec, whose key contains a string ngram_lm_scale_attention_scale and the value is the @@ -661,7 +668,11 @@ def rescore_with_attention_decoder( # TODO: pass the sos_token_id and eos_token_id via function arguments nll = model.decoder_nll( - expanded_memory, expanded_memory_key_padding_mask, token_ids, 1, 1 + memory=expanded_memory, + memory_key_padding_mask=expanded_memory_key_padding_mask, + token_ids=token_ids, + sos_id=sos_id, + eos_id=eos_id, ) assert nll.ndim == 2 assert nll.shape[0] == num_word_seqs
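
Reviewer note (not part of the patch): a minimal sketch of how the refactored interface is expected to be used after this change. The sizes N, T, num_features and num_classes are made up for illustration and simply mirror the new test_transformer.py; feature extraction, supervisions and CTC/attention decoding are omitted.

    import torch
    from transformer import Transformer

    # Hypothetical toy sizes; only the shapes matter here.
    num_features = 40
    num_classes = 87
    model = Transformer(num_features=num_features, num_classes=num_classes)

    N, T = 4, 100
    # Inputs stay [N, T, C] now; callers no longer permute to [N, C, T].
    features = torch.rand(N, T, num_features)

    # With this patch, nnet_output is already [N, T', C] (log-probs over classes),
    # memory is [T', N, d_model], and the padding mask is None when no supervisions
    # are passed, so decode.py/train.py no longer permute the output back.
    nnet_output, memory, memory_key_padding_mask = model(features)
    assert nnet_output.shape == (N, ((T - 1) // 2 - 1) // 2, num_classes)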