diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/README.md b/egs/librispeech/ASR/streaming_conformer_ctc/README.md new file mode 100644 index 000000000..01be7090b --- /dev/null +++ b/egs/librispeech/ASR/streaming_conformer_ctc/README.md @@ -0,0 +1,90 @@ +## Train and Decode +Commands of data preparation/train/decode steps are almost the same with +../conformer_ctc experiment except some options. + +Please read the code and understand following new added options before running this experiment: + + For data preparation: + + Nothing new. + + For streaming_conformer_ctc/train.py: + + --dynamic-chunk-training + --short-chunk-proportion + + For streaming_conformer_ctc/streaming_decode.py: + + --chunk-size + --tailing-num-frames + --simulate-streaming + +## Performence and a trained model. + +The latest results with this streaming code is shown in following table: + +chunk size | wer on test-clean | wer on test-other +-- | -- | -- +full | 3.53 | 8.52 +40(1.96s) | 3.78 | 9.38 +32(1.28s) | 3.82 | 9.44 +24(0.96s) | 3.95 | 9.76 +16(0.64s) | 4.06 | 9.98 +8(0.32s) | 4.30 | 10.55 +4(0.16s) | 5.88 | 12.01 + + +A trained model is also provided. +By run +``` +git clone https://huggingface.co/GuoLiyong/streaming_conformer + +# You may want to manually check md5sum values of downloaded files +# 8e633bc1de37f5ae57a2694ceee32a93 trained_streaming_conformer.pt +# 4c0aeefe26c784ec64873cc9b95420f1 L.pt +# d1f91d81005fb8ce4d65953a4a984ee7 Linv.pt +# e1c1902feb7b9fc69cd8d26e663c2608 bpe.model +# 8617e67159b0ff9118baa54f04db24cc tokens.txt +# 72b075ab5e851005cd854e666c82c3bb words.txt +``` + +If there is any different md5sum values, please run +``` +cd streaming_models +git lfs pull +``` +And check md5sum values again. + +Finally, following files will be downloaded: +
+streaming_models/  
+|-- lang_bpe  
+|   |-- L.pt  
+|   |-- Linv.pt  
+|   |-- bpe.model
+|   |-- tokens.txt
+|   `-- words.txt
+`-- trained_streaming_conformer.pt
+
+ + +And run commands you will get the same results of previous table: +``` +trained_models=/path/to/downloaded/streaming_models/ +for chunk_size in 4 8 16 24 36 40 -1; do + ./streaming_conformer_ctc/streaming_decode.py \ + --chunk-size=${chunk_size} \ + --trained-dir=${trained_models} +done +``` +Results of following command is indentical to previous one, +but model consumes features chunk_by_chunk, i.e. a streaming way. +``` +trained_models=/path/to/downloaded/streaming_models/ +for chunk_size in 4 8 16 24 36 40 -1; do + ./streaming_conformer_ctc/streaming_decode.py \ + --simulate-streaming=True \ + --chunk-size=${chunk_size} \ + --trained-dir=${trained_models} +done +``` diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/asr_datamodule.py b/egs/librispeech/ASR/streaming_conformer_ctc/asr_datamodule.py new file mode 120000 index 000000000..a73848de9 --- /dev/null +++ b/egs/librispeech/ASR/streaming_conformer_ctc/asr_datamodule.py @@ -0,0 +1 @@ +../conformer_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py new file mode 100644 index 000000000..33f8e2e15 --- /dev/null +++ b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py @@ -0,0 +1,1355 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +import warnings +from typing import Optional, Tuple + +import torch +from torch import Tensor, nn +from transformer import Supervisions, Transformer, encoder_padding_mask + + +# from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py#L42 +def subsequent_chunk_mask( + size: int, + chunk_size: int, + num_left_chunks: int = -1, + device: torch.device = torch.device("cpu"), +) -> torch.Tensor: + """Create mask for subsequent steps (size, size) with chunk size, + this is for streaming encoder + Args: + size (int): size of mask + chunk_size (int): size of chunk + num_left_chunks (int): number of left chunks + <0: use full chunk + >=0: use num_left_chunks + device (torch.device): "cpu" or "cuda" or torch.Tensor.device + Returns: + torch.Tensor: mask + Examples: + >>> subsequent_chunk_mask(4, 2) + [[1, 1, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 1], + [1, 1, 1, 1]] + """ + ret = torch.zeros(size, size, device=device, dtype=torch.bool) + for i in range(size): + if num_left_chunks < 0: + start = 0 + else: + start = max((i // chunk_size - num_left_chunks) * chunk_size, 0) + ending = min((i // chunk_size + 1) * chunk_size, size) + ret[i, start:ending] = True + return ret + + +class Conformer(Transformer): + """ + Args: + num_features (int): Number of input features + num_classes (int): Number of output classes + subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) + d_model (int): attention dimension + nhead (int): number of head + dim_feedforward (int): feedforward dimention + num_encoder_layers (int): number of encoder layers + num_decoder_layers (int): number of decoder layers + dropout (float): dropout rate + cnn_module_kernel (int): Kernel size of convolution module + normalize_before (bool): whether to use layer_norm before the first block. + vgg_frontend (bool): whether to use vgg frontend. + """ + + def __init__( + self, + num_features: int, + num_classes: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + num_decoder_layers: int = 6, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + normalize_before: bool = True, + vgg_frontend: bool = False, + use_feat_batchnorm: bool = False, + causal: bool = False, + ) -> None: + super(Conformer, self).__init__( + num_features=num_features, + num_classes=num_classes, + subsampling_factor=subsampling_factor, + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + dropout=dropout, + normalize_before=normalize_before, + vgg_frontend=vgg_frontend, + use_feat_batchnorm=use_feat_batchnorm, + ) + + self.encoder_pos = RelPositionalEncoding(d_model, dropout) + + encoder_layer = ConformerEncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + cnn_module_kernel, + normalize_before, + causal, + ) + self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) + self.normalize_before = normalize_before + if self.normalize_before: + self.after_norm = nn.LayerNorm(d_model) + else: + # Note: TorchScript detects that self.after_norm could be used inside forward() + # and throws an error without this change. + self.after_norm = identity + + def run_encoder( + self, + x: Tensor, + supervisions: Optional[Supervisions] = None, + dynamic_chunk_training: bool = False, + short_chunk_proportion: float = 0.5, + chunk_size: int = -1, + simulate_streaming: bool = False, + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x: + The model input. Its shape is (N, T, C). + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute encoder padding mask, which is used as memory key padding + mask for the decoder. + dynamic_chunk_training: + For training only. + IF True, train with dynamic right context for some batches + sampled with a distribution + if False, train with full right context all the time. + short_chunk_proportion: + For training only. + Proportion of samples that will be trained with dynamic chunk. + chunk_size: + For eval only. + right context when evaluating test utts. + -1 means all right context. + simulate_streaming=False, + For eval only. + If true, the feature will be feeded into the model chunk by chunk. + If false, the whole utts if feeded into the model together i.e. the + model only foward once. + + + Returns: + Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). + Tensor: Mask tensor of dimension (batch_size, input_length) + """ + if self.encoder.training: + return self.train_run_encoder( + x, supervisions, dynamic_chunk_training, short_chunk_proportion + ) + else: + return self.eval_run_encoder( + x, supervisions, chunk_size, simulate_streaming + ) + + def train_run_encoder( + self, + x: Tensor, + supervisions: Optional[Supervisions] = None, + dynamic_chunk_training: bool = False, + short_chunk_threshold: float = 0.5, + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x: + The model input. Its shape is (N, T, C). + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute encoder padding mask, which is used as memory key padding + mask for the decoder. + dynamic_chunk_training: + IF True, train with dynamic right context for some batches + sampled with a distribution + if False, train with full right context all the time. + short_chunk_proportion: + Proportion of samples that will be trained with dynamic chunk. + """ + x = self.encoder_embed(x) + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + src_key_padding_mask = encoder_padding_mask(x.size(0), supervisions) + if src_key_padding_mask is not None: + src_key_padding_mask = src_key_padding_mask.to(x.device) + + if dynamic_chunk_training: + max_len = x.size(0) + chunk_size = torch.randint(1, max_len, (1,)).item() + if chunk_size > (max_len * short_chunk_threshold): + chunk_size = max_len + else: + chunk_size = chunk_size % 25 + 1 + mask = ~subsequent_chunk_mask( + size=x.size(0), chunk_size=chunk_size, device=x.device + ) + x = self.encoder( + x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask + ) # (T, B, F) + else: + x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F) + + if self.normalize_before: + x = self.after_norm(x) + + return x, src_key_padding_mask + + def eval_run_encoder( + self, + feature: Tensor, + supervisions: Optional[Supervisions] = None, + chunk_size: int = -1, + simulate_streaming=False, + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + feature: + The model input. Its shape is (N, T, C). + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute encoder padding mask, which is used as memory key padding + mask for the decoder. + + Returns: + Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). + Tensor: Mask tensor of dimension (batch_size, input_length) + """ + # feature.shape: N T C + num_frames = feature.size(1) + + # As temporarily in icefall only subsampling_rate == 4 is supported, + # following parameters are hard-coded here. + # Change it accordingly if other subsamling_rate are supported. + embed_left_context = 7 + embed_conv_right_context = 3 + subsampling_rate = 4 + stride = chunk_size * subsampling_rate + decoding_window = embed_conv_right_context + stride + + # This is also only compatible to sumsampling_rate == 4 + length_after_subsampling = ((feature.size(1) - 1) // 2 - 1) // 2 + src_key_padding_mask = encoder_padding_mask( + length_after_subsampling, supervisions + ) + if src_key_padding_mask is not None: + src_key_padding_mask = src_key_padding_mask.to(feature.device) + + if chunk_size < 0: + # non-streaming decoding + x = self.encoder_embed(feature) + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + x = self.encoder( + x, pos_emb, src_key_padding_mask=src_key_padding_mask + ) # (T, B, F) + else: + if simulate_streaming: + # simulate chunk_by_chunk streaming decoding + # Results of this branch should be identical to following + # "else" branch. + # But this branch is a little slower + # as the feature is feeded chunk by chunk + + # store the result of chunk_by_chunk decoding + encoder_output = [] + + # caches + pos_emb_positive = [] + pos_emb_negative = [] + pos_emb_central = None + encoder_cache = [None for i in range(len(self.encoder.layers))] + conv_cache = [None for i in range(len(self.encoder.layers))] + + # start chunk_by_chunk decoding + offset = 0 + for cur in range( + 0, num_frames - embed_left_context + 1, stride + ): + end = min(cur + decoding_window, num_frames) + cur_feature = feature[:, cur:end, :] + cur_feature = self.encoder_embed(cur_feature) + cur_embed, cur_pos_emb = self.encoder_pos( + cur_feature, offset + ) + cur_embed = cur_embed.permute( + 1, 0, 2 + ) # (B, T, F) -> (T, B, F) + + cur_T = cur_feature.size(1) + if cur == 0: + # for first chunk extract the central pos embedding + pos_emb_central = cur_pos_emb[ + 0, (chunk_size - 1), : + ].view(1, 1, -1) + cur_T -= 1 + pos_emb_positive.append(cur_pos_emb[0, :cur_T].flip(0)) + pos_emb_negative.append(cur_pos_emb[0, -cur_T:]) + assert pos_emb_positive[-1].size(0) == cur_T + + pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze( + 0 + ) + pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze( + 0 + ) + cur_pos_emb = torch.cat( + [pos_emb_pos.flip(1), pos_emb_central, pos_emb_neg], + dim=1, + ) + + x = self.encoder.chunk_forward( + cur_embed, + cur_pos_emb, + src_key_padding_mask=src_key_padding_mask[ + :, : offset + cur_embed.size(0) + ], + encoder_cache=encoder_cache, + conv_cache=conv_cache, + offset=offset, + ) # (T, B, F) + encoder_output.append(x) + offset += cur_embed.size(0) + + x = torch.cat(encoder_output, dim=0) + else: + # NOT simulate chunk_by_chunk decoding + # Results of this branch should be identical to previous + # simulate chunk_by_chunk decoding branch. + # But this branch is faster. + x = self.encoder_embed(feature) + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + mask = ~subsequent_chunk_mask( + size=x.size(0), chunk_size=chunk_size, device=x.device + ) + x = self.encoder( + x, + pos_emb, + mask=mask, + src_key_padding_mask=src_key_padding_mask, + ) # (T, B, F) + + if self.normalize_before: + x = self.after_norm(x) + + return x, src_key_padding_mask + + +class ConformerEncoderLayer(nn.Module): + """ + ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. + See: "Conformer: Convolution-augmented Transformer for Speech Recognition" + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. + normalize_before: whether to use layer_norm before the first block. + + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + normalize_before: bool = True, + causal: bool = False, + ) -> None: + super(ConformerEncoderLayer, self).__init__() + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) + + self.feed_forward = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.feed_forward_macaron = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.conv_module = ConvolutionModule( + d_model, cnn_module_kernel, causal=causal + ) + + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module + self.norm_ff = nn.LayerNorm(d_model) # for the FNN module + self.norm_mha = nn.LayerNorm(d_model) # for the MHA module + + self.ff_scale = 0.5 + + self.norm_conv = nn.LayerNorm(d_model) # for the CNN module + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block + + self.dropout = nn.Dropout(dropout) + + self.normalize_before = normalize_before + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, N is the batch size, E is the feature number + """ + + # macaron style feed forward module + residual = src + if self.normalize_before: + src = self.norm_ff_macaron(src) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) + if not self.normalize_before: + src = self.norm_ff_macaron(src) + + # multi-headed self-attention module + residual = src + if self.normalize_before: + src = self.norm_mha(src) + src_att = self.self_attn( + src, + src, + src, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = residual + self.dropout(src_att) + if not self.normalize_before: + src = self.norm_mha(src) + + # convolution module + residual = src + if self.normalize_before: + src = self.norm_conv(src) + src = residual + self.dropout(self.conv_module(src)) + if not self.normalize_before: + src = self.norm_conv(src) + + # feed forward module + residual = src + if self.normalize_before: + src = self.norm_ff(src) + src = residual + self.ff_scale * self.dropout(self.feed_forward(src)) + if not self.normalize_before: + src = self.norm_ff(src) + + if self.normalize_before: + src = self.norm_final(src) + + return src + + def chunk_forward( + self, + src: Tensor, + pos_emb: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + encoder_cache: Optional[Tensor] = None, + conv_cache: Optional[Tensor] = None, + offset=0, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, N is the batch size, E is the feature number + """ + + # macaron style feed forward module + residual = src + if self.normalize_before: + src = self.norm_ff_macaron(src) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) + if not self.normalize_before: + src = self.norm_ff_macaron(src) + + # multi-headed self-attention module + residual = src + if self.normalize_before: + src = self.norm_mha(src) + if encoder_cache is None: + # src: [chunk_size, N, F] e.g. [8, 41, 512] + key = src + val = key + encoder_cache = key + else: + key = torch.cat([encoder_cache, src], dim=0) + val = key + encoder_cache = key + src_att = self.self_attn( + src, + key, + val, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + offset=offset, + )[0] + src = residual + self.dropout(src_att) + if not self.normalize_before: + src = self.norm_mha(src) + + # convolution module + residual = src # [chunk_size, N, F] e.g. [8, 41, 512] + if self.normalize_before: + src = self.norm_conv(src) + if conv_cache is not None: + src = torch.cat([conv_cache, src], dim=0) + conv_cache = src + + src = self.conv_module(src) + src = src[-residual.size(0) :, :, :] # noqa: E203 + + src = residual + self.dropout(src) + if not self.normalize_before: + src = self.norm_conv(src) + + # feed forward module + residual = src + if self.normalize_before: + src = self.norm_ff(src) + src = residual + self.ff_scale * self.dropout(self.feed_forward(src)) + if not self.normalize_before: + src = self.norm_ff(src) + + if self.normalize_before: + src = self.norm_final(src) + + return src, encoder_cache, conv_cache + + +class ConformerEncoder(nn.TransformerEncoder): + r"""ConformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the ConformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = conformer_encoder(src, pos_emb) + """ + + def __init__( + self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None + ) -> None: + super(ConformerEncoder, self).__init__( + encoder_layer=encoder_layer, num_layers=num_layers, norm=norm + ) + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + pos_emb: Positional embedding tensor (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + """ + output = src + + for mod in self.layers: + output = mod( + output, + pos_emb, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + ) + + if self.norm is not None: + output = self.norm(output) + + return output + + def chunk_forward( + self, + src: Tensor, + pos_emb: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + encoder_cache=None, + conv_cache=None, + offset=0, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + pos_emb: Positional embedding tensor (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + """ + output = src + + for layer_index, mod in enumerate(self.layers): + output, e_cache, c_cache = mod.chunk_forward( + output, + pos_emb, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + encoder_cache=encoder_cache[layer_index], + conv_cache=conv_cache[layer_index], + offset=offset, + ) + encoder_cache[layer_index] = e_cache + conv_cache[layer_index] = c_cache + + if self.norm is not None: + output = self.norm(output) + + return output + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module. + + See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. + + """ + + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: + """Construct an PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x: Tensor, offset: int = 0) -> None: + """Reset the positional encodings.""" + x_size_1 = offset + x.size(1) + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x_size_1 * 2 - 1: + # Note: TorchScript doesn't implement operator== for torch.Device + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vecotr and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). + + """ + self.extend_pe(x, offset) + x = x * self.xscale + x_size_1 = offset + x.size(1) + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - x_size_1 + + 1 : self.pe.size(1) // 2 # noqa E203 + + x_size_1, + ] + x_T = x.size(1) + if offset > 0: + pos_emb = torch.cat([pos_emb[:, :x_T], pos_emb[:, -x_T:]], dim=1) + else: + pos_emb = torch.cat( + [ + pos_emb[:, : (x_T - 1)], + self.pe[0, self.pe.size(1) // 2].view( + 1, 1, self.pe.size(-1) + ), + pos_emb[:, -(x_T - 1) :], # noqa: E203 + ], + dim=1, + ) + + return self.dropout(x), self.dropout(pos_emb) + + +class RelPositionMultiheadAttention(nn.Module): + r"""Multi-Head Attention layer with relative position encoding + + See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + + Examples:: + + >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + + # linear transformation for positional encoding. + self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + + self._reset_parameters() + + def _reset_parameters(self) -> None: + nn.init.xavier_uniform_(self.in_proj.weight) + nn.init.constant_(self.in_proj.bias, 0.0) + nn.init.constant_(self.out_proj.bias, 0.0) + + nn.init.xavier_uniform_(self.pos_bias_u) + nn.init.xavier_uniform_(self.pos_bias_v) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + offset=0, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + return self.multi_head_attention_forward( + query, + key, + value, + pos_emb, + self.embed_dim, + self.num_heads, + self.in_proj.weight, + self.in_proj.bias, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + offset=offset, + ) + + def rel_shift(self, x: Tensor, offset=0) -> Tensor: + """Compute relative positional encoding. + + Args: + x: Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + Tensor: tensor of shape (batch, head, time1, time2) + (note: time2 == time1 + offset, since it is for + the key, while time1 is for the query). + """ + (batch_size, num_heads, time1, n) = x.shape + time2 = time1 + offset + assert n == 2 * time2 - 1 + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + + return x.as_strided( + (batch_size, num_heads, time1, time2), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) + + def multi_head_attention_forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + offset=0, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence + length, N is the batch size, E is the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + assert ( + head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 + + if torch.equal(query, key) and torch.equal(key, value): + # self-attention + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) + + elif torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) + + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = nn.functional.linear(key, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = nn.functional.linear(value, _w, _b) + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + query.size(0), + key.size(0), + ]: + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim) + k = k.contiguous().view(-1, bsz, num_heads, head_dim) + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + src_len = k.size(0) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz, "{} == {}".format( + key_padding_mask.size(0), bsz + ) + assert key_padding_mask.size(1) == src_len, "{} == {}".format( + key_padding_mask.size(1), src_len + ) + + q = q.transpose(0, 1) # (batch, time1, head, d_k) + + pos_emb_bsz = pos_emb.size(0) + assert pos_emb_bsz in (1, bsz) # actually it is 1 + p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) + + # (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1) + p = p.permute(0, 2, 3, 1) + + q_with_bias_u = (q + self.pos_bias_u).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + q_with_bias_v = (q + self.pos_bias_v).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + # compute attention score + # first compute matrix a and matrix c + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) + + # compute matrix b and matrix d + matrix_bd = torch.matmul( + q_with_bias_v, p + ) # (batch, head, time1, 2*time1-1) + matrix_bd = self.rel_shift( + matrix_bd, offset=offset + ) # [B, head, time1, time2] + attn_output_weights = ( + matrix_ac + matrix_bd + ) * scaling # (batch, head, time1, time2) + + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) + + assert list(attn_output_weights.size()) == [ + bsz * num_heads, + tgt_len, + src_len, + ] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float("-inf")) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len + ) + + attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias + ) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, + channels: int, + kernel_size: int, + bias: bool = True, + causal: bool = False, + ) -> None: + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + # from https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/convolution.py#L41 + if causal: + self.lorder = kernel_size - 1 + padding = 0 # manualy padding self.lorder zeros to the left later + else: + assert (kernel_size - 1) % 2 == 0 + self.lorder = 0 + padding = (kernel_size - 1) // 2 + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=padding, + groups=channels, + bias=bias, + ) + self.norm = nn.BatchNorm1d(channels) + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = Swish() + + def forward(self, x: Tensor) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # 1D Depthwise Conv + if self.lorder > 0: + # manualy padding self.lorder zeros to the left + # make depthwise_conv causal + x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) + x = self.depthwise_conv(x) + x = self.activation(self.norm(x)) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x: Tensor) -> Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x) + + +def identity(x): + return x diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/label_smoothing.py b/egs/librispeech/ASR/streaming_conformer_ctc/label_smoothing.py new file mode 120000 index 000000000..08734abd7 --- /dev/null +++ b/egs/librispeech/ASR/streaming_conformer_ctc/label_smoothing.py @@ -0,0 +1 @@ +../conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py new file mode 100755 index 000000000..a74c51836 --- /dev/null +++ b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py @@ -0,0 +1,521 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from conformer import Conformer +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +# from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py#L166 +def remove_duplicates_and_blank(hyp: List[int]) -> List[int]: + new_hyp: List[int] = [] + cur = 0 + while cur < len(hyp): + if hyp[cur] != 0: + new_hyp.append(hyp[cur]) + prev = cur + while cur < len(hyp) and hyp[cur] == hyp[prev]: + cur += 1 + return new_hyp + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=34, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=20, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--chunk-size", + type=int, + default=8, + help="Frames of right context" + "-1 for whole right context, i.e. non-streaming decoding", + ) + + parser.add_argument( + "--tailing-num-frames", + type=int, + default=20, + help="tailing dummy frames padded to the right," + "only used during decoding", + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="simulate chunk by chunk decoding", + ) + parser.add_argument( + "--method", + type=str, + default="ctc-greedy-search", + help="Streaming Decoding method", + ) + + parser.add_argument( + "--export", + type=str2bool, + default=False, + help="""When enabled, the averaged model is saved to + conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved. + pretrained.pt contains a dict {"model": model.state_dict()}, + which can be loaded by `icefall.checkpoint.load_checkpoint()`. + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="streaming_conformer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--trained-dir", + type=Path, + default=None, + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe", + help="The lang dir", + ) + + parser.add_argument( + "--avg-models", + type=str, + default=None, + help="Manually select models to average, seperated by comma;" + "e.g. 60,62,63,72", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "exp_dir": Path("conformer_ctc/exp"), + "lang_dir": Path("data/lang_bpe"), + # parameters for conformer + "causal": True, + "subsampling_factor": 4, + "vgg_frontend": False, + "use_feat_batchnorm": True, + "feature_dim": 80, + "nhead": 8, + "attention_dim": 512, + "num_decoder_layers": 6, + # parameters for decoding + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + chunk_size: int = -1, + simulate_streaming=False, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + + model: + The neural model. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + sos_id: + The token ID of the SOS. + eos_id: + The token ID of the EOS. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + feature = batch["inputs"] + device = torch.device("cuda") + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + # Extra dummy tailing frames my reduce deletion error + # example WITHOUT padding: + # CHAPTER SEVEN ON THE RACES OF MAN + # example WITH padding: + # CHAPTER SEVEN ON THE RACES OF (MAN->*) + tailing_frames = ( + torch.tensor([-23.0259]) + .expand([feature.size(0), params.tailing_num_frames, 80]) + .to(feature.device) + ) + feature = torch.cat([feature, tailing_frames], dim=1) + supervisions["num_frames"] += params.tailing_num_frames + + nnet_output, memory, memory_key_padding_mask = model( + feature, + supervisions, + chunk_size=chunk_size, + simulate_streaming=simulate_streaming, + ) + + assert params.method == "ctc-greedy-search" + key = "ctc-greedy-search" + batch_size = nnet_output.size(0) + maxlen = nnet_output.size(1) + topk_prob, topk_index = nnet_output.topk(1, dim=2) # (B, maxlen, 1) + topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen) + topk_index = topk_index.masked_fill_( + memory_key_padding_mask, 0 + ) # (B, maxlen) + token_ids = [token_id.tolist() for token_id in topk_index] + token_ids = [ + remove_duplicates_and_blank(token_id) for token_id in token_ids + ] + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + return {key: hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + chunk_size: int = -1, + simulate_streaming=False, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + word_table: + It is the word symbol table. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + chunk_size: + right context to simulate streaming decoding + -1 for whole right context, i.e. non-stream decoding + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + results = [] + + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + + hyps_dict = decode_one_batch( + params=params, + model=model, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + sos_id=sos_id, + eos_id=eos_id, + chunk_size=chunk_size, + simulate_streaming=simulate_streaming, + ) + + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for hyp_words, ref_text in zip(hyps, texts): + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + + num_cuts += len(batch["supervisions"]["text"]) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + if params.method == "attention-decoder": + # Set it to False since there are too many logs. + enable_log = False + else: + enable_log = True + test_set_wers = dict() + if params.avg_models is not None: + avg_models = params.avg_models.replace(",", "_") + result_file_prefix = f"epoch-avg-{avg_models}-chunksize \ + -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-" + else: + result_file_prefix = f"epoch-{params.epoch}-avg-{params.avg}-chunksize \ + -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-" + for key, results in results_dict.items(): + recog_path = ( + params.exp_dir + / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt" + ) + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.exp_dir + / f"{result_file_prefix}-errs-{test_set_name}-{key}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=enable_log + ) + test_set_wers[key] = wer + + if enable_log: + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") + logging.info("Decoding started") + logging.info(params) + + if params.trained_dir is not None: + params.lang_dir = Path(params.trained_dir) / "lang_bpe" + # used naming result files + params.epoch = "trained_model" + params.avg = 1 + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + sos_id = graph_compiler.sos_id + eos_id = graph_compiler.eos_id + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + causal=params.causal, + ) + + if params.trained_dir is not None: + model_name = f"{params.trained_dir}/trained_streaming_conformer.pt" + load_checkpoint(model_name, model) + elif params.avg == 1 and params.avg_models is not None: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + filenames = [] + if params.avg_models is not None: + model_ids = params.avg_models.split(",") + for i in model_ids: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + else: + start = params.epoch - params.avg + 1 + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames)) + + if params.export: + logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") + torch.save( + {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" + ) + return + + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + # CAUTION: `test_sets` is for displaying only. + # If you want to skip test-clean, you have to skip + # it inside the for loop. That is, use + # + # if test_set == 'test-clean': continue + # + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + test_sets = ["test-clean", "test-other"] + for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + bpe_model=bpe_model, + word_table=lexicon.word_table, + sos_id=sos_id, + eos_id=eos_id, + chunk_size=params.chunk_size, + simulate_streaming=params.simulate_streaming, + ) + + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) + + logging.info("Done!") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/subsampling.py b/egs/librispeech/ASR/streaming_conformer_ctc/subsampling.py new file mode 120000 index 000000000..6fee09e58 --- /dev/null +++ b/egs/librispeech/ASR/streaming_conformer_ctc/subsampling.py @@ -0,0 +1 @@ +../conformer_ctc/subsampling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/train.py b/egs/librispeech/ASR/streaming_conformer_ctc/train.py new file mode 100755 index 000000000..8b4d6701e --- /dev/null +++ b/egs/librispeech/ASR/streaming_conformer_ctc/train.py @@ -0,0 +1,745 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Optional, Tuple + +import k2 +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from conformer import Conformer +from lhotse.utils import fix_random_seed +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter +from transformer import Noam + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + encode_supervisions, + setup_logger, + str2bool, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=78, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + conformer_ctc/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--att-rate", + type=float, + default=0.8, + help="""The attention rate. + The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss + """, + ) + + parser.add_argument( + "--dynamic-chunk-training", + type=str2bool, + default=True, + help="Whether to use dynamic right context during training.", + ) + + parser.add_argument( + "--short-chunk-proportion", + type=float, + default=0.7, + help="Proportion of samples trained with short right context", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - use_feat_batchnorm: Whether to do batch normalization for the + input features. + + - attention_dim: Hidden dim for multi-head attention model. + + - head: Number of heads of multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - beam_size: It is used in k2.ctc_loss + + - reduction: It is used in k2.ctc_loss + + - use_double_scores: It is used in k2.ctc_loss + + - weight_decay: The weight_decay for the optimizer. + + - lr_factor: The lr_factor for Noam optimizer. + + - warm_step: The warm_step for Noam optimizer. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "use_feat_batchnorm": True, + "attention_dim": 512, + "nhead": 8, + "num_decoder_layers": 6, + # parameters for loss + "beam_size": 10, + "reduction": "sum", + "use_double_scores": True, + # parameters for Noam + "weight_decay": 1e-6, + "lr_factor": 5.0, + "warm_step": 80000, + "env_info": get_env_info(), + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. + + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + batch: dict, + graph_compiler: BpeCtcTrainingGraphCompiler, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + """ + device = graph_compiler.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + with torch.set_grad_enabled(is_training): + nnet_output, encoder_memory, memory_mask = model( + feature, + supervisions, + dynamic_chunk_training=params.dynamic_chunk_training, + short_chunk_proportion=params.short_chunk_proportion, + ) + # nnet_output is (N, T, C) + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=params.subsampling_factor + ) + + token_ids = graph_compiler.texts_to_ids(texts) + + decoding_graph = graph_compiler.compile(token_ids) + + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + + if params.att_rate != 0.0: + with torch.set_grad_enabled(is_training): + mmodel = model.module if hasattr(model, "module") else model + # Note: We need to generate an unsorted version of token_ids + # `encode_supervisions()` called above sorts text, but + # encoder_memory and memory_mask are not sorted, so we + # use an unsorted version `supervisions["text"]` to regenerate + # the token_ids + # + # See https://github.com/k2-fsa/icefall/issues/97 + # for more details + unsorted_token_ids = graph_compiler.texts_to_ids( + supervisions["text"] + ) + att_loss = mmodel.decoder_forward( + encoder_memory, + memory_mask, + token_ids=unsorted_token_ids, + sos_id=graph_compiler.sos_id, + eos_id=graph_compiler.eos_id, + ) + loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss + else: + loss = ctc_loss + att_loss = torch.tensor([0]) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + info["frames"] = supervision_segments[:, 2].sum().item() + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.att_rate != 0.0: + info["att_loss"] = att_loss.detach().cpu().item() + + info["loss"] = loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + graph_compiler: BpeCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + graph_compiler: + It is used to convert transcripts to FSAs. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" + ) + + if batch_idx % params.log_interval == 0: + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(42) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + + logging.info("About to create model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=False, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + model = DDP(model, device_ids=[rank]) + + optimizer = Noam( + model.parameters(), + model_size=params.attention_dim, + factor=params.lr_factor, + warm_step=params.warm_step, + weight_decay=params.weight_decay, + ) + + if checkpoints: + optimizer.load_state_dict(checkpoints["optimizer"]) + + librispeech = LibriSpeechAsrDataModule(args) + train_dl = librispeech.train_dataloaders() + valid_dl = librispeech.valid_dataloaders() + + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + for epoch in range(params.start_epoch, params.num_epochs): + train_dl.sampler.set_epoch(epoch) + + cur_lr = optimizer._rate + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + if rank == 0: + logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + ) + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: nn.Module, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 0 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + optimizer.zero_grad() + loss, _ = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py new file mode 100644 index 000000000..bc78e4a41 --- /dev/null +++ b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py @@ -0,0 +1,966 @@ +# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from label_smoothing import LabelSmoothingLoss +from subsampling import Conv2dSubsampling, VggSubsampling +from torch.nn.utils.rnn import pad_sequence + +# Note: TorchScript requires Dict/List/etc. to be fully typed. +Supervisions = Dict[str, torch.Tensor] + + +class Transformer(nn.Module): + def __init__( + self, + num_features: int, + num_classes: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + num_decoder_layers: int = 6, + dropout: float = 0.1, + normalize_before: bool = True, + vgg_frontend: bool = False, + use_feat_batchnorm: bool = False, + ) -> None: + """ + Args: + num_features: + The input dimension of the model. + num_classes: + The output dimension of the model. + subsampling_factor: + Number of output frames is num_in_frames // subsampling_factor. + Currently, subsampling_factor MUST be 4. + d_model: + Attention dimension. + nhead: + Number of heads in multi-head attention. + Must satisfy d_model // nhead == 0. + dim_feedforward: + The output dimension of the feedforward layers in encoder/decoder. + num_encoder_layers: + Number of encoder layers. + num_decoder_layers: + Number of decoder layers. + dropout: + Dropout in encoder/decoder. + normalize_before: + If True, use pre-layer norm; False to use post-layer norm. + vgg_frontend: + True to use vgg style frontend for subsampling. + use_feat_batchnorm: + True to use batchnorm for the input layer. + """ + super().__init__() + self.use_feat_batchnorm = use_feat_batchnorm + if use_feat_batchnorm: + self.feat_batchnorm = nn.BatchNorm1d(num_features) + + self.num_features = num_features + self.num_classes = num_classes + self.subsampling_factor = subsampling_factor + if subsampling_factor != 4: + raise NotImplementedError("Support only 'subsampling_factor=4'.") + + # self.encoder_embed converts the input of shape (N, T, num_classes) + # to the shape (N, T//subsampling_factor, d_model). + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_classes -> d_model + if vgg_frontend: + self.encoder_embed = VggSubsampling(num_features, d_model) + else: + self.encoder_embed = Conv2dSubsampling(num_features, d_model) + + self.encoder_pos = PositionalEncoding(d_model, dropout) + + encoder_layer = TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + normalize_before=normalize_before, + ) + + if normalize_before: + encoder_norm = nn.LayerNorm(d_model) + else: + encoder_norm = None + + self.encoder = nn.TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=num_encoder_layers, + norm=encoder_norm, + ) + + # TODO(fangjun): remove dropout + self.encoder_output_layer = nn.Sequential( + nn.Dropout(p=dropout), nn.Linear(d_model, num_classes) + ) + + if num_decoder_layers > 0: + self.decoder_num_class = ( + self.num_classes + ) # bpe model already has sos/eos symbol + + self.decoder_embed = nn.Embedding( + num_embeddings=self.decoder_num_class, embedding_dim=d_model + ) + self.decoder_pos = PositionalEncoding(d_model, dropout) + + decoder_layer = TransformerDecoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + normalize_before=normalize_before, + ) + + if normalize_before: + decoder_norm = nn.LayerNorm(d_model) + else: + decoder_norm = None + + self.decoder = nn.TransformerDecoder( + decoder_layer=decoder_layer, + num_layers=num_decoder_layers, + norm=decoder_norm, + ) + + self.decoder_output_layer = torch.nn.Linear( + d_model, self.decoder_num_class + ) + + self.decoder_criterion = LabelSmoothingLoss() + else: + self.decoder_criterion = None + + def forward( + self, + x: torch.Tensor, + supervision: Optional[Supervisions] = None, + dynamic_chunk_training: bool = False, + short_chunk_proportion: float = 0.5, + chunk_size: int = -1, + simulate_streaming=False, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (N, T, C). + supervision: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) + + Returns: + Return a tuple containing 3 tensors: + - CTC output for ctc decoding. Its shape is (N, T, C) + - Encoder output with shape (T, N, C). It can be used as key and + value for the decoder. + - Encoder output padding mask. It can be used as + memory_key_padding_mask for the decoder. Its shape is (N, T). + It is None if `supervision` is None. + """ + if self.use_feat_batchnorm: + x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) + x = self.feat_batchnorm(x) + x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) + encoder_memory, memory_key_padding_mask = self.run_encoder( + x, + supervision, + dynamic_chunk_training=dynamic_chunk_training, + short_chunk_proportion=short_chunk_proportion, + chunk_size=chunk_size, + simulate_streaming=simulate_streaming, + ) + x = self.ctc_output(encoder_memory) + return x, encoder_memory, memory_key_padding_mask + + def run_encoder( + self, + x: torch.Tensor, + supervisions: Optional[Supervisions] = None, + chunk_size: int = -1, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Run the transformer encoder. + + Args: + x: + The model input. Its shape is (N, T, C). + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute the encoder padding mask, which is used as memory key + padding mask for the decoder. + chunk_size: right chunk_size to simulate streaming decoding + -1 for whole right context + Returns: + Return a tuple with two tensors: + - The encoder output, with shape (T, N, C) + - encoder padding mask, with shape (N, T). + The mask is None if `supervisions` is None. + It is used as memory key padding mask in the decoder. + """ + # streaming decoding(chunk_size >= 0) is only verified with Conformer + assert chunk_size == -1 + x = self.encoder_embed(x) + x = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + mask = encoder_padding_mask(x.size(0), supervisions) + mask = mask.to(x.device) if mask is not None else None + x = self.encoder( + x, src_key_padding_mask=mask, chunk_size=chunk_size + ) # (T, N, C) + + return x, mask + + def ctc_output(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + The output tensor from the transformer encoder. + Its shape is (T, N, C) + + Returns: + Return a tensor that can be used for CTC decoding. + Its shape is (N, T, C) + """ + x = self.encoder_output_layer(x) + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + x = nn.functional.log_softmax(x, dim=-1) # (N, T, C) + return x + + @torch.jit.export + def decoder_forward( + self, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, + token_ids: List[List[int]], + sos_id: int, + eos_id: int, + ) -> torch.Tensor: + """ + Args: + memory: + It's the output of the encoder with shape (T, N, C) + memory_key_padding_mask: + The padding mask from the encoder. + token_ids: + A list-of-list IDs. Each sublist contains IDs for an utterance. + The IDs can be either phone IDs or word piece IDs. + sos_id: + sos token id + eos_id: + eos token id + + Returns: + A scalar, the **sum** of label smoothing loss over utterances + in the batch without any normalization. + """ + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) + + device = memory.device + ys_in_pad = ys_in_pad.to(device) + ys_out_pad = ys_out_pad.to(device) + + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) + + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + # TODO: Use length information to create the decoder padding mask + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. + tgt_key_padding_mask[:, 0] = False + + tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) + tgt = self.decoder_pos(tgt) + tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + pred_pad = self.decoder( + tgt=tgt, + memory=memory, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + ) # (T, N, C) + pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) + + decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad) + + return decoder_loss + + @torch.jit.export + def decoder_nll( + self, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, + token_ids: List[torch.Tensor], + sos_id: int, + eos_id: int, + ) -> torch.Tensor: + """ + Args: + memory: + It's the output of the encoder with shape (T, N, C) + memory_key_padding_mask: + The padding mask from the encoder. + token_ids: + A list-of-list IDs (e.g., word piece IDs). + Each sublist represents an utterance. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + Returns: + A 2-D tensor of shape (len(token_ids), max_token_length) + representing the cross entropy loss (i.e., negative log-likelihood). + """ + # The common part between this function and decoder_forward could be + # extracted as a separate function. + if isinstance(token_ids[0], torch.Tensor): + # This branch is executed by torchscript in C++. + # See https://github.com/k2-fsa/k2/pull/870 + # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286 + token_ids = [tolist(t) for t in token_ids] + + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) + + device = memory.device + ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) + ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) + + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) + + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + # TODO: Use length information to create the decoder padding mask + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. + tgt_key_padding_mask[:, 0] = False + + tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) + tgt = self.decoder_pos(tgt) + tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + pred_pad = self.decoder( + tgt=tgt, + memory=memory, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + ) # (T, B, F) + pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) + pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) + # nll: negative log-likelihood + nll = torch.nn.functional.cross_entropy( + pred_pad.view(-1, self.decoder_num_class), + ys_out_pad.view(-1), + ignore_index=-1, + reduction="none", + ) + + nll = nll.view(pred_pad.shape[0], -1) + + return nll + + +class TransformerEncoderLayer(nn.Module): + """ + Modified from torch.nn.TransformerEncoderLayer. + Add support of normalize_before, + i.e., use layer_norm before the first block. + + Args: + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). + normalize_before: + whether to use layer_norm before the first block. + + Examples:: + >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> out = encoder_layer(src) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "relu", + normalize_before: bool = True, + ) -> None: + super(TransformerEncoderLayer, self).__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + self.normalize_before = normalize_before + + def __setstate__(self, state): + if "activation" not in state: + state["activation"] = nn.functional.relu + super(TransformerEncoderLayer, self).__setstate__(state) + + def forward( + self, + src: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional) + + Shape: + src: (S, N, E). + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number + """ + residual = src + if self.normalize_before: + src = self.norm1(src) + src2 = self.self_attn( + src, + src, + src, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = residual + self.dropout1(src2) + if not self.normalize_before: + src = self.norm1(src) + + residual = src + if self.normalize_before: + src = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = residual + self.dropout2(src2) + if not self.normalize_before: + src = self.norm2(src) + return src + + +class TransformerDecoderLayer(nn.Module): + """ + Modified from torch.nn.TransformerDecoderLayer. + Add support of normalize_before, + i.e., use layer_norm before the first block. + + Args: + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). + + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = decoder_layer(tgt, memory) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "relu", + normalize_before: bool = True, + ) -> None: + super(TransformerDecoderLayer, self).__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + self.src_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + self.normalize_before = normalize_before + + def __setstate__(self, state): + if "activation" not in state: + state["activation"] = nn.functional.relu + super(TransformerDecoderLayer, self).__setstate__(state) + + def forward( + self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: + the sequence to the decoder layer (required). + memory: + the sequence from the last layer of the encoder (required). + tgt_mask: + the mask for the tgt sequence (optional). + memory_mask: + the mask for the memory sequence (optional). + tgt_key_padding_mask: + the mask for the tgt keys per batch (optional). + memory_key_padding_mask: + the mask for the memory keys per batch (optional). + + Shape: + tgt: (T, N, E). + memory: (S, N, E). + tgt_mask: (T, T). + memory_mask: (T, S). + tgt_key_padding_mask: (N, T). + memory_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number + """ + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + tgt2 = self.self_attn( + tgt, + tgt, + tgt, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask, + )[0] + tgt = residual + self.dropout1(tgt2) + if not self.normalize_before: + tgt = self.norm1(tgt) + + residual = tgt + if self.normalize_before: + tgt = self.norm2(tgt) + tgt2 = self.src_attn( + tgt, + memory, + memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = residual + self.dropout2(tgt2) + if not self.normalize_before: + tgt = self.norm2(tgt) + + residual = tgt + if self.normalize_before: + tgt = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = residual + self.dropout3(tgt2) + if not self.normalize_before: + tgt = self.norm3(tgt) + return tgt + + +def _get_activation_fn(activation: str): + if activation == "relu": + return nn.functional.relu + elif activation == "gelu": + return nn.functional.gelu + + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) + + +class PositionalEncoding(nn.Module): + """This class implements the positional encoding + proposed in the following paper: + + - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf + + PE(pos, 2i) = sin(pos / (10000^(2i/d_modle)) + PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle)) + + Note:: + + 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model))) + = exp(-1* 2i / d_model * log(100000)) + = exp(2i * -(log(10000) / d_model)) + """ + + def __init__(self, d_model: int, dropout: float = 0.1) -> None: + """ + Args: + d_model: + Embedding dimension. + dropout: + Dropout probability to be applied to the output of this module. + """ + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = nn.Dropout(p=dropout) + # not doing: self.pe = None because of errors thrown by torchscript + self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32) + + def extend_pe(self, x: torch.Tensor) -> None: + """Extend the time t in the positional encoding if required. + + The shape of `self.pe` is (1, T1, d_model). The shape of the input x + is (N, T, d_model). If T > T1, then we change the shape of self.pe + to (N, T, d_model). Otherwise, nothing is done. + + Args: + x: + It is a tensor of shape (N, T, C). + Returns: + Return None. + """ + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + # Now pe is of shape (1, T, d_model), where T is x.size(1) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Add positional encoding. + + Args: + x: + Its shape is (N, T, C) + + Returns: + Return a tensor of shape (N, T, C) + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1), :] + return self.dropout(x) + + +class Noam(object): + """ + Implements Noam optimizer. + + Proposed in + "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf + + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa + + Args: + params: + iterable of parameters to optimize or dicts defining parameter groups + model_size: + attention dimension of the transformer model + factor: + learning rate factor + warm_step: + warmup steps + """ + + def __init__( + self, + params, + model_size: int = 256, + factor: float = 10.0, + warm_step: int = 25000, + weight_decay=0, + ) -> None: + """Construct an Noam object.""" + self.optimizer = torch.optim.Adam( + params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay + ) + self._step = 0 + self.warmup = warm_step + self.factor = factor + self.model_size = model_size + self._rate = 0 + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step=None): + """Implement `lrate` above.""" + if step is None: + step = self._step + return ( + self.factor + * self.model_size ** (-0.5) + * min(step ** (-0.5), step * self.warmup ** (-1.5)) + ) + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + "warmup": self.warmup, + "factor": self.factor, + "model_size": self.model_size, + "_rate": self._rate, + "optimizer": self.optimizer.state_dict(), + } + + def load_state_dict(self, state_dict): + """Load state_dict.""" + for key, value in state_dict.items(): + if key == "optimizer": + self.optimizer.load_state_dict(state_dict["optimizer"]) + else: + setattr(self, key, value) + + +def encoder_padding_mask( + max_len: int, supervisions: Optional[Supervisions] = None +) -> Optional[torch.Tensor]: + """Make mask tensor containing indexes of padded part. + + TODO:: + This function **assumes** that the model uses + a subsampling factor of 4. We should remove that + assumption later. + + Args: + max_len: + Maximum length of input features. + CAUTION: It is the length after subsampling. + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) + + Returns: + Tensor: Mask tensor of dimension (batch_size, input_length), + True denote the masked indices. + """ + if supervisions is None: + return None + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"], + supervisions["num_frames"], + ), + 1, + ).to(torch.int32) + + lengths = [ + 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) + ] + for idx in range(supervision_segments.size(0)): + # Note: TorchScript doesn't allow to unpack tensors as tuples + sequence_idx = supervision_segments[idx, 0].item() + start_frame = supervision_segments[idx, 1].item() + num_frames = supervision_segments[idx, 2].item() + lengths[sequence_idx] = start_frame + num_frames + + lengths = [((i - 1) // 2 - 1) // 2 for i in lengths] + bs = int(len(lengths)) + seq_range = torch.arange(0, max_len, dtype=torch.int64) + seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len) + # Note: TorchScript doesn't implement Tensor.new() + seq_length_expand = torch.tensor( + lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype + ).unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + + return mask + + +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: + """Generate a length mask for input. + + The masked position are filled with True, + Unmasked positions are filled with False. + + Args: + ys_pad: + padded tensor of dimension (batch_size, input_length). + ignore_id: + the ignored number (the padding number) in ys_pad + + Returns: + Tensor: + a bool tensor of the same shape as the input tensor. + """ + ys_mask = ys_pad == ignore_id + return ys_mask + + +def generate_square_subsequent_mask(sz: int) -> torch.Tensor: + """Generate a square mask for the sequence. The masked positions are + filled with float('-inf'). Unmasked positions are filled with float(0.0). + The mask can be used for masked self-attention. + + For instance, if sz is 3, it returns:: + + tensor([[0., -inf, -inf], + [0., 0., -inf], + [0., 0., 0]]) + + Args: + sz: mask size + + Returns: + A square mask of dimension (sz, sz) + """ + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = ( + mask.float() + .masked_fill(mask == 0, float("-inf")) + .masked_fill(mask == 1, float(0.0)) + ) + return mask + + +def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: + """Prepend sos_id to each utterance. + + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + sos_id: + The ID of the SOS token. + + Return: + Return a new list-of-list, where each sublist starts + with SOS ID. + """ + return [[sos_id] + utt for utt in token_ids] + + +def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: + """Append eos_id to each utterance. + + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + eos_id: + The ID of the EOS token. + + Return: + Return a new list-of-list, where each sublist ends + with EOS ID. + """ + return [utt + [eos_id] for utt in token_ids] + + +def tolist(t: torch.Tensor) -> List[int]: + """Used by jit""" + return torch.jit.annotate(List[int], t.tolist())