diff --git a/egs/librispeech/ASR/transducer/conformer.py b/egs/librispeech/ASR/transducer/conformer.py
index b19b94db1..22977b835 100644
--- a/egs/librispeech/ASR/transducer/conformer.py
+++ b/egs/librispeech/ASR/transducer/conformer.py
@@ -22,20 +22,21 @@ from typing import Optional, Tuple

 import torch
 from torch import Tensor, nn
-from transformer import Supervisions, Transformer, encoder_padding_mask
+from transducer.transformer import Transformer
+
+from icefall.utils import make_pad_mask


 class Conformer(Transformer):
     """
     Args:
         num_features (int): Number of input features
-        num_classes (int): Number of output classes
+        output_dim (int): Output dimension of the model
         subsampling_factor (int): subsampling factor of encoder
             (the convolution layers before transformers)
         d_model (int): attention dimension
         nhead (int): number of head
         dim_feedforward (int): feedforward dimention
         num_encoder_layers (int): number of encoder layers
-        num_decoder_layers (int): number of decoder layers
         dropout (float): dropout rate
         cnn_module_kernel (int): Kernel size of convolution module
         normalize_before (bool): whether to use layer_norm before the first block.
@@ -45,13 +46,12 @@ class Conformer(Transformer):
     def __init__(
         self,
         num_features: int,
-        num_classes: int,
+        output_dim: int,
         subsampling_factor: int = 4,
         d_model: int = 256,
         nhead: int = 4,
         dim_feedforward: int = 2048,
         num_encoder_layers: int = 12,
-        num_decoder_layers: int = 6,
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
@@ -60,13 +60,12 @@ class Conformer(Transformer):
     ) -> None:
         super(Conformer, self).__init__(
             num_features=num_features,
-            num_classes=num_classes,
+            output_dim=output_dim,
             subsampling_factor=subsampling_factor,
             d_model=d_model,
             nhead=nhead,
             dim_feedforward=dim_feedforward,
             num_encoder_layers=num_encoder_layers,
-            num_decoder_layers=num_decoder_layers,
             dropout=dropout,
             normalize_before=normalize_before,
             vgg_frontend=vgg_frontend,
@@ -92,38 +91,45 @@ class Conformer(Transformer):
         # and throws an error without this change.
         self.after_norm = identity

-    def run_encoder(
-        self, x: Tensor, supervisions: Optional[Supervisions] = None
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+    def forward(
+        self, x: torch.Tensor, x_lens: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
           x:
-            The model input. Its shape is (N, T, C).
-          supervisions:
-            Supervision in lhotse format.
-            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
-            CAUTION: It contains length information, i.e., start and number of
-            frames, before subsampling
-            It is read directly from the batch, without any sorting. It is used
-            to compute encoder padding mask, which is used as memory key padding
-            mask for the decoder.
-
+            The input tensor. Its shape is (batch_size, seq_len, feature_dim).
+          x_lens:
+            A tensor of shape (batch_size,) containing the number of frames in
+            `x` before padding.
         Returns:
-          Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
-          Tensor: Mask tensor of dimension (batch_size, input_length)
+          Return a tuple containing 2 tensors:
+            - logits, its shape is (batch_size, output_seq_len, output_dim)
+            - logit_lens, a tensor of shape (batch_size,) containing the number
+              of frames in `logits` before padding.
""" + if self.use_feat_batchnorm: + x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) + x = self.feat_batchnorm(x) + x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) + x = self.encoder_embed(x) x, pos_emb = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) - mask = encoder_padding_mask(x.size(0), supervisions) - if mask is not None: - mask = mask.to(x.device) - x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + # Caution: We assume the subsampling factor is 4! + lengths = ((x_lens - 1) // 2 - 1) // 2 + assert x.size(0) == lengths.max().item() + mask = make_pad_mask(lengths) + + x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, N, C) if self.normalize_before: x = self.after_norm(x) - return x, mask + logits = self.encoder_output_layer(x) + logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return logits, lengths class ConformerEncoderLayer(nn.Module): diff --git a/egs/librispeech/ASR/transducer/encoder_interface.py b/egs/librispeech/ASR/transducer/encoder_interface.py index afb0bcfd1..257facce4 100644 --- a/egs/librispeech/ASR/transducer/encoder_interface.py +++ b/egs/librispeech/ASR/transducer/encoder_interface.py @@ -21,16 +21,6 @@ import torch.nn as nn class EncoderInterface(nn.Module): - def __init__(self, num_features: int, output_dim: int): - """ - Args: - num_features: - The dimension of the input features. - output_dim: - Output dimension of the model. - """ - super().__init__() - def forward( self, x: torch.Tensor, x_lens: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: diff --git a/egs/librispeech/ASR/transducer/test_conformer.py b/egs/librispeech/ASR/transducer/test_conformer.py new file mode 100755 index 000000000..98f7df78a --- /dev/null +++ b/egs/librispeech/ASR/transducer/test_conformer.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./transducer/test_conformer.py +""" + +import torch +from transducer.conformer import Conformer + + +def test_conformer(): + output_dim = 1024 + conformer = Conformer( + num_features=80, + output_dim=output_dim, + subsampling_factor=4, + d_model=512, + nhead=8, + dim_feedforward=2048, + num_encoder_layers=12, + use_feat_batchnorm=True, + ) + N = 3 + T = 100 + C = 80 + x = torch.randn(N, T, C) + x_lens = torch.tensor([50, 100, 80]) + logits, logit_lens = conformer(x, x_lens) + + expected_T = ((T - 1) // 2 - 1) // 2 + assert logits.shape == (N, expected_T, output_dim) + assert logit_lens.max().item() == expected_T + print(logits.shape) + print(logit_lens) + + +def main(): + test_conformer() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/transducer/test_decoder.py b/egs/librispeech/ASR/transducer/test_decoder.py index a883eb78f..eacadf5a3 100755 --- a/egs/librispeech/ASR/transducer/test_decoder.py +++ b/egs/librispeech/ASR/transducer/test_decoder.py @@ -18,7 +18,7 @@ """ To run this file, do: - cd icefall/egs/yesno/ASR + cd icefall/egs/librispeech/ASR python ./transducer/test_decoder.py """ diff --git a/egs/librispeech/ASR/transducer/test_joiner.py b/egs/librispeech/ASR/transducer/test_joiner.py index 2773ca319..b187c5ac6 100755 --- a/egs/librispeech/ASR/transducer/test_joiner.py +++ b/egs/librispeech/ASR/transducer/test_joiner.py @@ -18,7 +18,7 @@ """ To run this file, do: - cd icefall/egs/yesno/ASR + cd icefall/egs/librispeech/ASR python ./transducer/test_joiner.py """ diff --git a/egs/librispeech/ASR/transducer/test_rnn.py b/egs/librispeech/ASR/transducer/test_rnn.py index c7d524f7d..d5adac482 100755 --- a/egs/librispeech/ASR/transducer/test_rnn.py +++ b/egs/librispeech/ASR/transducer/test_rnn.py @@ -15,6 +15,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./transducer/test_rnn.py +""" import torch import torch.nn as nn from transducer.rnn import ( diff --git a/egs/librispeech/ASR/transducer/test_transducer.py b/egs/librispeech/ASR/transducer/test_transducer.py new file mode 100755 index 000000000..a65843e9b --- /dev/null +++ b/egs/librispeech/ASR/transducer/test_transducer.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./transducer/test_transducer.py +""" + + +import k2 +import torch +from transducer.conformer import Conformer +from transducer.decoder import Decoder +from transducer.joiner import Joiner +from transducer.model import Transducer + + +def test_transducer(): + # encoder params + input_dim = 10 + output_dim = 20 + + # decoder params + vocab_size = 3 + blank_id = 0 + sos_id = 2 + embedding_dim = 128 + num_layers = 2 + + encoder = Conformer( + num_features=input_dim, + output_dim=output_dim, + subsampling_factor=4, + d_model=512, + nhead=8, + dim_feedforward=2048, + num_encoder_layers=12, + use_feat_batchnorm=True, + ) + + decoder = Decoder( + vocab_size=vocab_size, + embedding_dim=embedding_dim, + blank_id=blank_id, + sos_id=sos_id, + num_layers=num_layers, + hidden_dim=output_dim, + embedding_dropout=0.0, + rnn_dropout=0.0, + ) + + joiner = Joiner(output_dim, vocab_size) + transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner) + + y = k2.RaggedTensor([[1, 2, 1], [1, 1, 1, 2, 1]]) + N = y.dim0 + T = 50 + + x = torch.rand(N, T, input_dim) + x_lens = torch.randint(low=30, high=T, size=(N,), dtype=torch.int32) + x_lens[0] = T + + loss = transducer(x, x_lens, y) + print(loss) + + +def main(): + test_transducer() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/transducer/test_transformer.py b/egs/librispeech/ASR/transducer/test_transformer.py new file mode 100755 index 000000000..5e35d56a6 --- /dev/null +++ b/egs/librispeech/ASR/transducer/test_transformer.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./transducer/test_transformer.py
+"""
+
+import torch
+from transducer.transformer import Transformer
+
+
+def test_transformer():
+    output_dim = 1024
+    transformer = Transformer(
+        num_features=80,
+        output_dim=output_dim,
+        subsampling_factor=4,
+        d_model=512,
+        nhead=8,
+        dim_feedforward=2048,
+        num_encoder_layers=12,
+        use_feat_batchnorm=True,
+    )
+    N = 3
+    T = 100
+    C = 80
+    x = torch.randn(N, T, C)
+    x_lens = torch.tensor([50, 100, 80])
+    logits, logit_lens = transformer(x, x_lens)
+
+    expected_T = ((T - 1) // 2 - 1) // 2
+    assert logits.shape == (N, expected_T, output_dim)
+    assert logit_lens.max().item() == expected_T
+    print(logits.shape)
+    print(logit_lens)
+
+
+def main():
+    test_transformer()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/transducer/transformer.py b/egs/librispeech/ASR/transducer/transformer.py
index f93914aaa..ffe0ba722 100644
--- a/egs/librispeech/ASR/transducer/transformer.py
+++ b/egs/librispeech/ASR/transducer/transformer.py
@@ -16,29 +16,26 @@

 import math
-from typing import Dict, List, Optional, Tuple
+from typing import Optional, Tuple

 import torch
 import torch.nn as nn
-from label_smoothing import LabelSmoothingLoss
 from subsampling import Conv2dSubsampling, VggSubsampling
-from torch.nn.utils.rnn import pad_sequence
+from transducer.encoder_interface import EncoderInterface

-# Note: TorchScript requires Dict/List/etc. to be fully typed.
-Supervisions = Dict[str, torch.Tensor]
+from icefall.utils import make_pad_mask


-class Transformer(nn.Module):
+class Transformer(EncoderInterface):
     def __init__(
         self,
         num_features: int,
-        num_classes: int,
+        output_dim: int,
         subsampling_factor: int = 4,
         d_model: int = 256,
         nhead: int = 4,
         dim_feedforward: int = 2048,
         num_encoder_layers: int = 12,
-        num_decoder_layers: int = 6,
         dropout: float = 0.1,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
@@ -48,7 +45,7 @@ class Transformer(nn.Module):
         Args:
           num_features:
             The input dimension of the model.
-          num_classes:
+          output_dim:
             The output dimension of the model.
           subsampling_factor:
             Number of output frames is num_in_frames // subsampling_factor.
@@ -59,13 +56,11 @@ class Transformer(nn.Module):
             Number of heads in multi-head attention.
             Must satisfy d_model // nhead == 0.
           dim_feedforward:
-            The output dimension of the feedforward layers in encoder/decoder.
+            The output dimension of the feedforward layers in the encoder.
           num_encoder_layers:
             Number of encoder layers.
-          num_decoder_layers:
-            Number of decoder layers.
           dropout:
-            Dropout in encoder/decoder.
+            Dropout in the encoder.
           normalize_before:
             If True, use pre-layer norm; False to use post-layer norm.
           vgg_frontend:
@@ -79,16 +74,16 @@ class Transformer(nn.Module):
             self.feat_batchnorm = nn.BatchNorm1d(num_features)

         self.num_features = num_features
-        self.num_classes = num_classes
+        self.output_dim = output_dim
         self.subsampling_factor = subsampling_factor
         if subsampling_factor != 4:
             raise NotImplementedError("Support only 'subsampling_factor=4'.")

-        # self.encoder_embed converts the input of shape (N, T, num_classes)
+        # self.encoder_embed converts the input of shape (N, T, num_features)
         # to the shape (N, T//subsampling_factor, d_model).
# That is, it does two things simultaneously: # (1) subsampling: T -> T//subsampling_factor - # (2) embedding: num_classes -> d_model + # (2) embedding: num_features -> d_model if vgg_frontend: self.encoder_embed = VggSubsampling(num_features, d_model) else: @@ -117,279 +112,45 @@ class Transformer(nn.Module): # TODO(fangjun): remove dropout self.encoder_output_layer = nn.Sequential( - nn.Dropout(p=dropout), nn.Linear(d_model, num_classes) + nn.Dropout(p=dropout), nn.Linear(d_model, output_dim) ) - if num_decoder_layers > 0: - self.decoder_num_class = ( - self.num_classes - ) # bpe model already has sos/eos symbol - - self.decoder_embed = nn.Embedding( - num_embeddings=self.decoder_num_class, embedding_dim=d_model - ) - self.decoder_pos = PositionalEncoding(d_model, dropout) - - decoder_layer = TransformerDecoderLayer( - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - dropout=dropout, - normalize_before=normalize_before, - ) - - if normalize_before: - decoder_norm = nn.LayerNorm(d_model) - else: - decoder_norm = None - - self.decoder = nn.TransformerDecoder( - decoder_layer=decoder_layer, - num_layers=num_decoder_layers, - norm=decoder_norm, - ) - - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) - - self.decoder_criterion = LabelSmoothingLoss() - else: - self.decoder_criterion = None - def forward( - self, x: torch.Tensor, supervision: Optional[Supervisions] = None - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: x: - The input tensor. Its shape is (N, T, C). - supervision: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - (CAUTION: It contains length information, i.e., start and number of - frames, before subsampling) - + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. Returns: - Return a tuple containing 3 tensors: - - CTC output for ctc decoding. Its shape is (N, T, C) - - Encoder output with shape (T, N, C). It can be used as key and - value for the decoder. - - Encoder output padding mask. It can be used as - memory_key_padding_mask for the decoder. Its shape is (N, T). - It is None if `supervision` is None. + Return a tuple containing 2 tensors: + - logits, its shape is (batch_size, output_seq_len, output_dim) + - logit_lens, a tensor of shape (batch_size,) containing the number + of frames in `logits` before padding. """ if self.use_feat_batchnorm: x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) - x = self.ctc_output(encoder_memory) - return x, encoder_memory, memory_key_padding_mask - def run_encoder( - self, x: torch.Tensor, supervisions: Optional[Supervisions] = None - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """Run the transformer encoder. - - Args: - x: - The model input. Its shape is (N, T, C). - supervisions: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - CAUTION: It contains length information, i.e., start and number of - frames, before subsampling - It is read directly from the batch, without any sorting. 
It is used
-            to compute the encoder padding mask, which is used as memory key
-            padding mask for the decoder.
-        Returns:
-          Return a tuple with two tensors:
-            - The encoder output, with shape (T, N, C)
-            - encoder padding mask, with shape (N, T).
-              The mask is None if `supervisions` is None.
-              It is used as memory key padding mask in the decoder.
-        """
         x = self.encoder_embed(x)
         x = self.encoder_pos(x)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
-        mask = encoder_padding_mask(x.size(0), supervisions)
-        mask = mask.to(x.device) if mask is not None else None
+
+        # Caution: We assume the subsampling factor is 4!
+        lengths = ((x_lens - 1) // 2 - 1) // 2
+        assert x.size(0) == lengths.max().item()
+
+        mask = make_pad_mask(lengths)
         x = self.encoder(x, src_key_padding_mask=mask)  # (T, N, C)

-        return x, mask
+        logits = self.encoder_output_layer(x)
+        logits = logits.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)

-    def ctc_output(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Args:
-          x:
-            The output tensor from the transformer encoder.
-            Its shape is (T, N, C)
-
-        Returns:
-          Return a tensor that can be used for CTC decoding.
-          Its shape is (N, T, C)
-        """
-        x = self.encoder_output_layer(x)
-        x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
-        x = nn.functional.log_softmax(x, dim=-1)  # (N, T, C)
-        return x
-
-    @torch.jit.export
-    def decoder_forward(
-        self,
-        memory: torch.Tensor,
-        memory_key_padding_mask: torch.Tensor,
-        token_ids: List[List[int]],
-        sos_id: int,
-        eos_id: int,
-    ) -> torch.Tensor:
-        """
-        Args:
-          memory:
-            It's the output of the encoder with shape (T, N, C)
-          memory_key_padding_mask:
-            The padding mask from the encoder.
-          token_ids:
-            A list-of-list IDs. Each sublist contains IDs for an utterance.
-            The IDs can be either phone IDs or word piece IDs.
-          sos_id:
-            sos token id
-          eos_id:
-            eos token id
-
-        Returns:
-            A scalar, the **sum** of label smoothing loss over utterances
-            in the batch without any normalization.
-        """
-        ys_in = add_sos(token_ids, sos_id=sos_id)
-        ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
-
-        ys_out = add_eos(token_ids, eos_id=eos_id)
-        ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
-
-        device = memory.device
-        ys_in_pad = ys_in_pad.to(device)
-        ys_out_pad = ys_out_pad.to(device)
-
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
-
-        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
-        # TODO: Use length information to create the decoder padding mask
-        # We set the first column to False since the first column in ys_in_pad
-        # contains sos_id, which is the same as eos_id in our current setting.
- tgt_key_padding_mask[:, 0] = False - - tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) - tgt = self.decoder_pos(tgt) - tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - pred_pad = self.decoder( - tgt=tgt, - memory=memory, - tgt_mask=tgt_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - ) # (T, N, C) - pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) - pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) - - decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad) - - return decoder_loss - - @torch.jit.export - def decoder_nll( - self, - memory: torch.Tensor, - memory_key_padding_mask: torch.Tensor, - token_ids: List[torch.Tensor], - sos_id: int, - eos_id: int, - ) -> torch.Tensor: - """ - Args: - memory: - It's the output of the encoder with shape (T, N, C) - memory_key_padding_mask: - The padding mask from the encoder. - token_ids: - A list-of-list IDs (e.g., word piece IDs). - Each sublist represents an utterance. - sos_id: - The token ID for SOS. - eos_id: - The token ID for EOS. - Returns: - A 2-D tensor of shape (len(token_ids), max_token_length) - representing the cross entropy loss (i.e., negative log-likelihood). - """ - # The common part between this function and decoder_forward could be - # extracted as a separate function. - if isinstance(token_ids[0], torch.Tensor): - # This branch is executed by torchscript in C++. - # See https://github.com/k2-fsa/k2/pull/870 - # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286 - token_ids = [tolist(t) for t in token_ids] - - ys_in = add_sos(token_ids, sos_id=sos_id) - ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) - - ys_out = add_eos(token_ids, eos_id=eos_id) - ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) - - device = memory.device - ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) - ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) - - tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) - # TODO: Use length information to create the decoder padding mask - # We set the first column to False since the first column in ys_in_pad - # contains sos_id, which is the same as eos_id in our current setting. - tgt_key_padding_mask[:, 0] = False - - tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) - tgt = self.decoder_pos(tgt) - tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) - pred_pad = self.decoder( - tgt=tgt, - memory=memory, - tgt_mask=tgt_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - ) # (T, B, F) - pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) - pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) - # nll: negative log-likelihood - nll = torch.nn.functional.cross_entropy( - pred_pad.view(-1, self.decoder_num_class), - ys_out_pad.view(-1), - ignore_index=-1, - reduction="none", - ) - - nll = nll.view(pred_pad.shape[0], -1) - - return nll + return logits, lengths class TransformerEncoderLayer(nn.Module): @@ -494,138 +255,6 @@ class TransformerEncoderLayer(nn.Module): return src -class TransformerDecoderLayer(nn.Module): - """ - Modified from torch.nn.TransformerDecoderLayer. 
- Add support of normalize_before, - i.e., use layer_norm before the first block. - - Args: - d_model: - the number of expected features in the input (required). - nhead: - the number of heads in the multiheadattention models (required). - dim_feedforward: - the dimension of the feedforward network model (default=2048). - dropout: - the dropout value (default=0.1). - activation: - the activation function of intermediate layer, relu or - gelu (default=relu). - - Examples:: - >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) - >>> memory = torch.rand(10, 32, 512) - >>> tgt = torch.rand(20, 32, 512) - >>> out = decoder_layer(tgt, memory) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - activation: str = "relu", - normalize_before: bool = True, - ) -> None: - super(TransformerDecoderLayer, self).__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) - self.src_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.norm3 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - self.dropout3 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - - self.normalize_before = normalize_before - - def __setstate__(self, state): - if "activation" not in state: - state["activation"] = nn.functional.relu - super(TransformerDecoderLayer, self).__setstate__(state) - - def forward( - self, - tgt: torch.Tensor, - memory: torch.Tensor, - tgt_mask: Optional[torch.Tensor] = None, - memory_mask: Optional[torch.Tensor] = None, - tgt_key_padding_mask: Optional[torch.Tensor] = None, - memory_key_padding_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Pass the inputs (and mask) through the decoder layer. - - Args: - tgt: - the sequence to the decoder layer (required). - memory: - the sequence from the last layer of the encoder (required). - tgt_mask: - the mask for the tgt sequence (optional). - memory_mask: - the mask for the memory sequence (optional). - tgt_key_padding_mask: - the mask for the tgt keys per batch (optional). - memory_key_padding_mask: - the mask for the memory keys per batch (optional). - - Shape: - tgt: (T, N, E). - memory: (S, N, E). - tgt_mask: (T, T). - memory_mask: (T, S). - tgt_key_padding_mask: (N, T). - memory_key_padding_mask: (N, S). 
- S is the source sequence length, T is the target sequence length, - N is the batch size, E is the feature number - """ - residual = tgt - if self.normalize_before: - tgt = self.norm1(tgt) - tgt2 = self.self_attn( - tgt, - tgt, - tgt, - attn_mask=tgt_mask, - key_padding_mask=tgt_key_padding_mask, - )[0] - tgt = residual + self.dropout1(tgt2) - if not self.normalize_before: - tgt = self.norm1(tgt) - - residual = tgt - if self.normalize_before: - tgt = self.norm2(tgt) - tgt2 = self.src_attn( - tgt, - memory, - memory, - attn_mask=memory_mask, - key_padding_mask=memory_key_padding_mask, - )[0] - tgt = residual + self.dropout2(tgt2) - if not self.normalize_before: - tgt = self.norm2(tgt) - - residual = tgt - if self.normalize_before: - tgt = self.norm3(tgt) - tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) - tgt = residual + self.dropout3(tgt2) - if not self.normalize_before: - tgt = self.norm3(tgt) - return tgt - - def _get_activation_fn(activation: str): if activation == "relu": return nn.functional.relu @@ -798,149 +427,3 @@ class Noam(object): self.optimizer.load_state_dict(state_dict["optimizer"]) else: setattr(self, key, value) - - -def encoder_padding_mask( - max_len: int, supervisions: Optional[Supervisions] = None -) -> Optional[torch.Tensor]: - """Make mask tensor containing indexes of padded part. - - TODO:: - This function **assumes** that the model uses - a subsampling factor of 4. We should remove that - assumption later. - - Args: - max_len: - Maximum length of input features. - CAUTION: It is the length after subsampling. - supervisions: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - (CAUTION: It contains length information, i.e., start and number of - frames, before subsampling) - - Returns: - Tensor: Mask tensor of dimension (batch_size, input_length), - True denote the masked indices. - """ - if supervisions is None: - return None - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - supervisions["start_frame"], - supervisions["num_frames"], - ), - 1, - ).to(torch.int32) - - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] - for idx in range(supervision_segments.size(0)): - # Note: TorchScript doesn't allow to unpack tensors as tuples - sequence_idx = supervision_segments[idx, 0].item() - start_frame = supervision_segments[idx, 1].item() - num_frames = supervision_segments[idx, 2].item() - lengths[sequence_idx] = start_frame + num_frames - - lengths = [((i - 1) // 2 - 1) // 2 for i in lengths] - bs = int(len(lengths)) - seq_range = torch.arange(0, max_len, dtype=torch.int64) - seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len) - # Note: TorchScript doesn't implement Tensor.new() - seq_length_expand = torch.tensor( - lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype - ).unsqueeze(-1) - mask = seq_range_expand >= seq_length_expand - - return mask - - -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: - """Generate a length mask for input. - - The masked position are filled with True, - Unmasked positions are filled with False. - - Args: - ys_pad: - padded tensor of dimension (batch_size, input_length). - ignore_id: - the ignored number (the padding number) in ys_pad - - Returns: - Tensor: - a bool tensor of the same shape as the input tensor. 
- """ - ys_mask = ys_pad == ignore_id - return ys_mask - - -def generate_square_subsequent_mask(sz: int) -> torch.Tensor: - """Generate a square mask for the sequence. The masked positions are - filled with float('-inf'). Unmasked positions are filled with float(0.0). - The mask can be used for masked self-attention. - - For instance, if sz is 3, it returns:: - - tensor([[0., -inf, -inf], - [0., 0., -inf], - [0., 0., 0]]) - - Args: - sz: mask size - - Returns: - A square mask of dimension (sz, sz) - """ - mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) - mask = ( - mask.float() - .masked_fill(mask == 0, float("-inf")) - .masked_fill(mask == 1, float(0.0)) - ) - return mask - - -def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: - """Prepend sos_id to each utterance. - - Args: - token_ids: - A list-of-list of token IDs. Each sublist contains - token IDs (e.g., word piece IDs) of an utterance. - sos_id: - The ID of the SOS token. - - Return: - Return a new list-of-list, where each sublist starts - with SOS ID. - """ - return [[sos_id] + utt for utt in token_ids] - - -def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: - """Append eos_id to each utterance. - - Args: - token_ids: - A list-of-list of token IDs. Each sublist contains - token IDs (e.g., word piece IDs) of an utterance. - eos_id: - The ID of the EOS token. - - Return: - Return a new list-of-list, where each sublist ends - with EOS ID. - """ - return [utt + [eos_id] for utt in token_ids] - - -def tolist(t: torch.Tensor) -> List[int]: - """Used by jit""" - return torch.jit.annotate(List[int], t.tolist())
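Note (a reviewer sketch, not part of the patch): both new `forward()` implementations derive the post-subsampling output lengths as `((x_lens - 1) // 2 - 1) // 2` (two stride-2 subsampling stages, an overall factor of 4) and build the padding mask with `icefall.utils.make_pad_mask`. The self-contained sketch below illustrates that length bookkeeping; the local `make_pad_mask` is a stand-in written under the assumption that it matches the icefall helper, i.e., it returns True at padded positions.

    import torch

    def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
        # (batch,) -> (batch, max_len); True marks frames beyond each
        # sequence's true length, i.e., padding positions.
        max_len = int(lengths.max().item())
        seq = torch.arange(max_len, device=lengths.device)
        return seq.unsqueeze(0) >= lengths.unsqueeze(1)

    x_lens = torch.tensor([50, 100, 80])    # frames before padding
    lengths = ((x_lens - 1) // 2 - 1) // 2  # frames after 4x subsampling
    print(lengths)                          # tensor([11, 24, 19])

    mask = make_pad_mask(lengths)
    print(mask.shape)                       # torch.Size([3, 24])
    print(mask[1].any())                    # tensor(False): the longest
                                            # sequence has no padding

This also matches the `assert x.size(0) == lengths.max().item()` added in both files: after subsampling, the encoder's time dimension equals the longest post-subsampling length.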