From a75f75bbad3b12973f86d127d43856c087e50fcb Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 18 Sep 2021 11:34:35 +0800
Subject: [PATCH] Fix bugs

---
 .../ASR/conformer_ctc_bn_2d/conformer.py   | 402 +++++++-----------
 .../ASR/conformer_ctc_bn_2d/transformer.py |   3 +
 2 files changed, 167 insertions(+), 238 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py b/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
index 249997a39..4aca39ec0 100644
--- a/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
+++ b/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
@@ -16,19 +16,23 @@
 # limitations under the License.
 
 import copy
+import random
 import math
 import warnings
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List
 
 import torch_flow_sampling
 import torch
 from torch import Tensor, nn
+from subsampling import Conv2dSubsampling, VggSubsampling
 
 from transformer import Supervisions, TransformerEncoderLayer, TransformerDecoderLayer, encoder_padding_mask, \
-    LabelSmoothingLoss, PositionalEncoding, pad_sequence, add_sos, add_eos, decoder_padding_mask
+    LabelSmoothingLoss, PositionalEncoding, pad_sequence, add_sos, add_eos, decoder_padding_mask, \
+    generate_square_subsequent_mask
 
 
 class ConformerTrunk(nn.Module):
-    def __init__(num_features: int,
+    def __init__(self,
+                 num_features: int,
                  subsampling_factor: int = 4,
                  d_model: int = 256,
                  nhead: int = 4,
@@ -37,6 +41,7 @@ class ConformerTrunk(nn.Module):
                  dropout: float = 0.1,
                  cnn_module_kernel: int = 31,
                  use_feat_batchnorm: bool = True) -> None:
+
         super(ConformerTrunk, self).__init__()
 
         if use_feat_batchnorm:
             self.feat_batchnorm = nn.BatchNorm1d(num_features)
@@ -62,12 +67,11 @@ class ConformerTrunk(nn.Module):
             cnn_module_kernel,
         )
 
-        self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
-
+        self.encoder = ConformerEncoder(encoder_layer, num_layers)
 
     def forward(
         self, x: torch.Tensor, supervision: Optional[Supervisions] = None
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         """
         Args:
          x:
          supervision:
            Supervision in lhotse format.
            (CAUTION: It contains length information, i.e., start and number of
             frames, before subsampling)
 
-        Returns:
-          Return a tuple containing 2 tensors:
-            - Encoder output with shape [T, N, C]. It can be used as key and
-              value for the decoder.
-            - Encoder output padding mask. It can be used as
-              memory_key_padding_mask for the decoder. Its shape is [N, T].
+        Return (output, pos_emb, mask), where:
+          output: The output embedding, of shape (T, N, C).
+          pos_emb: The positional embedding (this will be used by ctc_encoder_forward()).
+          mask: The output padding mask, a Tensor of bool, of shape [N, T].
+            It is None if `supervision` is None.
""" if hasattr(self, 'feat_batchnorm'): @@ -92,13 +94,15 @@ class ConformerTrunk(nn.Module): x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C] x = self.feat_embed(x) - x = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - mask = encoder_padding_mask(x.size(0), supervisions) - mask = mask.to(x.device) if mask is not None else None - x = self.encoder(x, key_padding_mask=mask) # (T, N, C) - return x, mask + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + mask = encoder_padding_mask(x.size(0), supervision) + + mask = mask.to(x.device) if mask is not None else None + x = self.encoder(x, pos_emb=pos_emb, key_padding_mask=mask) # (T, N, C) + + return x, pos_emb, mask class BidirectionalConformer(nn.Module): @@ -178,25 +182,26 @@ class BidirectionalConformer(nn.Module): in the discrete bottleneck """ def __init__( - num_features: int, - num_classes: int, - subsampling_factor: int = 4, - d_model: int = 256, - nhead: int = 4, - dim_feedforward: int = 2048, - num_trunk_encoder_layers: int = 12, - num_ctc_encoder_layers: int = 4, - num_decoder_layers: int = 6, - num_reverse_encoder_layers: int = 4, - num_reverse_decoder_layers: int = 4, - num_self_predictor_layers: int = 3, - bypass_bottleneck: bool = True, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - is_bpe: bool = False, - use_feat_batchnorm: bool = True, - discrete_bottleneck_tot_classes: int = 512, - discrete_bottleneck_num_groups: int = 4 + self, + num_features: int, + num_classes: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_trunk_encoder_layers: int = 12, + num_ctc_encoder_layers: int = 4, + num_decoder_layers: int = 6, + num_reverse_encoder_layers: int = 4, + num_reverse_decoder_layers: int = 4, + num_self_predictor_layers: int = 3, + bypass_bottleneck: bool = True, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + is_bpe: bool = False, + use_feat_batchnorm: bool = True, + discrete_bottleneck_tot_classes: int = 512, + discrete_bottleneck_num_groups: int = 4 ) -> None: super(BidirectionalConformer, self).__init__() @@ -236,7 +241,7 @@ class BidirectionalConformer(nn.Module): self.token_embed_scale = d_model ** 0.5 self.token_embed = nn.Embedding( num_embeddings=self.decoder_num_class, embedding_dim=d_model, - _weight=torch.randn(self.num_classes, d_model) * (1 / self.embed_scale) + _weight=torch.randn(self.decoder_num_class, d_model) * (1 / self.token_embed_scale) ) decoder_layer = TransformerDecoderLayer( @@ -315,8 +320,7 @@ class BidirectionalConformer(nn.Module): if num_self_predictor_layers > 0: encoder_layer = SimpleCausalEncoderLayer(d_model, dropout=dropout) - self.self_predictor_encoder = simple_causal_encoder(encoder_layer, - num_self_predictor_layers) + self.self_predictor_encoder = encoder_layer self.discrete_bottleneck = DiscreteBottleneck( @@ -326,8 +330,8 @@ class BidirectionalConformer(nn.Module): - def forward(self, x: Tensor, supervision: Optional[Supervisions], - need_softmax: bool = True) -> Tuple[Tensor, Tensor, Optional[Tensor]]: + def forward(self, x: Tensor, supervision: Optional[Supervisions] = None, + need_softmax: bool = True) -> Tuple[Tensor, Tensor, Tensor, Tensor, Optional[Tensor]]: """ Forward function that "encodes" the features. @@ -343,26 +347,29 @@ class BidirectionalConformer(nn.Module): If true, the last output ("softmax") will be computed. This can be useful in the reverse model, but only necessary if straight_through_scale != 1.0. 
-        Returns: (memory, bn_memory, sampled, softmax, key_padding_mask), where:
+        Returns: (memory, bn_memory, pos_emb, sampled, softmax, key_padding_mask), where:
 
-          `memory` is a Tensor of shape [T, N, E] i.e. [T, batch_size, embedding_dim] where T
-             is actually a subsampled form of the num_frames of the input `x`.
-             If self.bypass_bottleneck, it will be taken before the discrete
-             bottleneck; otherwise, from after.
-          `bn_memory` is the same shape as `memory`, but comes after the discrete bottleneck
-             regardless of the value of self.bypass_bottleneck.
-          `sampled` is a Tensor of shape [T, N, C] where C corresponds to `discrete_bottleneck_tot_classes`
-             as given to the constructor.  This will be needed for the 'reverse' model.
-          `softmax` is a "soft" version of `sampled`.  Will only be returned if need_softmax == True;
+          memory: a Tensor of shape [T, N, E], i.e. [T, batch_size, embedding_dim], where T
+             is actually a subsampled form of the num_frames of the input `x`.
+             If self.bypass_bottleneck, it will be taken from before the discrete
+             bottleneck; otherwise, from after.
+          bn_memory: The same shape as `memory`, but comes after the discrete bottleneck
+             regardless of the value of self.bypass_bottleneck.
+          pos_emb: The relative positional embedding; will be given to ctc_encoder_forward().
+          sampled: a Tensor of shape [T, N, C] where C corresponds to `discrete_bottleneck_tot_classes`
+             as given to the constructor.  This will be needed for the 'reverse' model.
+          softmax: a "soft" version of `sampled`.  Will only be returned if need_softmax == True;
             else will be None.
+          key_padding_mask: The padding mask for the "memory" output, a Tensor of bool of
+             shape [N, T] (only if supervision was supplied, else None).
        """
-        encoder_output, memory_key_padding_mask = self.trunk(x, supervision)
+        encoder_output, pos_emb, memory_key_padding_mask = self.trunk(x, supervision)
 
        bn_memory, sampled, softmax = self.discrete_bottleneck(encoder_output)
 
        memory = encoder_output if self.bypass_bottleneck else bn_memory
 
-        return (memory, bn_memory, sampled, softmax, memory_key_padding_mask)
+        return (memory, bn_memory, pos_emb, sampled, softmax, memory_key_padding_mask)
 
    def decoder_forward(
        self,
        memory: torch.Tensor,
        memory_key_padding_mask: torch.Tensor,
        token_ids: List[List[int]],
        sos_id: int,
        eos_id: int,
    ) -> torch.Tensor:
        """
+        Compute the decoder loss function (given a particular list of hypotheses).
+
        Args:
          memory:
-            It's the first output of forward(), with shape [T, N, E]
+            The first output of forward(), with shape [T, N, E]
          memory_key_padding_mask:
-            The padding mask from forward()
+            The padding mask from forward(), a tensor of bool with shape [N, T]
          token_ids:
            A list of lists of IDs.  Each sublist contains the IDs for one utterance.
            The IDs can be either phone IDs or word piece IDs.
          sos_id:
            sos token id
          eos_id:
            eos token id
 
        Returns:
            A scalar, the **sum** of label smoothing loss over utterances
            in the batch without any normalization.
""" + ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) @@ -426,6 +436,7 @@ class BidirectionalConformer(nn.Module): tgt_mask=tgt_mask, memory_key_padding_mask=memory_key_padding_mask, ) # (T, N, C) + pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) @@ -436,6 +447,7 @@ class BidirectionalConformer(nn.Module): def ctc_encoder_forward( self, memory: torch.Tensor, + pos_emb: torch.Tensor, memory_key_padding_mask: torch.Tensor, ) -> torch.Tensor: """ @@ -444,17 +456,19 @@ class BidirectionalConformer(nn.Module): Args: memory: - It's the output of forward(), with shape [T, N, E] + It's the output of forward(), with shape (T, N, E) + pos_emb: + Relative positional embedding tensor: (N, 2*T-1, E) memory_key_padding_mask: - The padding mask from forward() + The padding mask from forward(), a tensor of bool of shape (N, T) Returns: A Tensor with shape [N, T, C] where C is the number of classes (e.g. number of phones or word pieces). Contains normalized log-probabilities. """ - x = self.ctc_encoder(memory, + pos_emb, key_padding_mask=memory_key_padding_mask) x = self.ctc_output_layer(x) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) @@ -470,10 +484,10 @@ class BidirectionalConformer(nn.Module): softmax: Optional[torch.Tensor], reverse_gradient: bool = True) -> Tensor: """ - Returns the total log-prob of the the - labels sampled in the discrete bottleneck layer, as predicted using a relatively - simple model from previous frames sampled from the bottleneck layer. - [Appears on the denominator of an expressin for mutual information]. + Returns the total log-prob of the the labels sampled in the discrete + bottleneck layer, as predicted using a relatively simple model that + predicts from previous frames sampled from the bottleneck layer. + [Appears on the denominator of an expression for mutual information]. Args: memory_shifted: @@ -482,21 +496,25 @@ class BidirectionalConformer(nn.Module): (T, N, E) = memory.shape memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1,:,:]), dim=0) memory_key_padding_mask: - The padding mask from the encoder. - sampled: is a Tensor of shape [T, N, C] where C corresponds to `discrete_bottleneck_tot_classes` - as given to the constructor. This will be needed for the 'reverse' model. + The padding mask from the encoder, of shape [N, T], boolean, True + for masked locations. + sampled: sampled and interpolated one-hot values, as a Tensor of shape [T, N, C] + where C corresponds to `discrete_bottleneck_tot_classes` + as given to the constructor. This will be needed for the 'reverse' + model. softmax: is a "soft" version of `sampled`; if None, will default to `sampled`. reverse_gradient: will likely be true. If true, the gradient is reversed twice in this computation, so that we train predictors with the correct gradient, i.e. to predict, not anti-predict (since the return value of this function will appear with positive, not negative, sign in the loss function, so will be minimized). + The gradient w.r.t. the non-self inputs to this function, though (i.e. + memory_shifted, sampled, softmax) will not be reversed, though. Returns: A scalar tensor, the **sum** of label smoothing loss over utterances in the batch without any normalization. 
""" - if reverse_gradient: # Reversing gradient for memory_shifted puts the gradient back into # the correct sign; we reversed it in @@ -506,6 +524,8 @@ class BidirectionalConformer(nn.Module): # what happens to the gradients). memory_shifted = ReverseGrad.apply(memory_shifted) + # no mask is needed for self_predictor_encoder; its CNN + # layer uses left-padding only, making it causal. predictor = self.self_predictor_encoder(memory_shifted) prob = self.discrete_bottleneck.compute_prob(predictor, @@ -566,15 +586,21 @@ class BidirectionalConformer(nn.Module): tokens_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=padding_id) + # Let S be the length of the longest sentence (padded) + token_embedding = self.token_embed(tokens_padded) * self.token_embed_scale # (N, S) -> (N, S, C) + # add absolute position-encoding information + token_embedding = self.abs_pos(token_embedding) + + token_embedding = token_embedding.permute(1, 0, 2) # (N, S, C) -> (S, N, C) + + token_memory = self.reverse_encoder(token_embedding, + src_key_padding_mask=tokens_key_padding_mask) + # token_memory is of shape (S, N, C), if S is length of token sequence. + T = memory.shape[0] # the targets, here, are the hidden discrete symbols we are predicting tgt_mask = generate_square_subsequent_mask(T, device=memory.device) - token_embedding = self.token_embed(tokens_padded) * self.token_embed_scale - token_memory = self.reverse_encoder(token_embedding, - src_key_padding_mask=tokens_key_padding_mask) - # tokens_encoded is of shape (S, N, C), if S is length of token sequence. - hidden_predictor = self.reverse_decoder( tgt=memory_shifted, memory=token_memory, @@ -584,25 +610,13 @@ class BidirectionalConformer(nn.Module): total_prob = self.discrete_bottleneck.compute_prob( hidden_predictor, - TODO, # HERE - ) + sampled, + softmax, + memory_key_padding_mask) - tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) - tgt = self.decoder_pos(tgt) - tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - pred_pad = self.decoder( - tgt=tgt, - memory=memory, - tgt_mask=tgt_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - ) # (T, N, C) - pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) - pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) + # TODO: consider using a label-smoothed loss. 
+ return total_prob - decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad) - - return decoder_loss class SimpleCausalEncoderLayer(nn.Module): """ @@ -650,129 +664,6 @@ class SimpleCausalEncoderLayer(nn.Module): src = self.norm_final(src) return src -# for search: SimpleCausalEncoder -def simple_causal_encoder(encoder_layer: nn.Module, - num_layers: int): - return torch.nn.Sequential([copy.deepcopy(encoder_leyer) for _ in range(num_layers)]) - - - -class DiscreteBottleneckConformer(Transformer): - """ - Args: - num_features (int): Number of input features - num_classes (int): Number of output classes - subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) - d_model (int): attention dimension - nhead (int): number of head - dim_feedforward (int): feedforward dimention - num_encoder_layers (int): number of encoder layers - num_decoder_layers (int): number of decoder layers - dropout (float): dropout rate - cnn_module_kernel (int): Kernel size of convolution module - discrete_bottleneck_pos (int): position in the encoder at which to place - the discrete bottleneck (this many encoder layers will - precede it) - - """ - - def __init__( - self, - num_features: int, - num_classes: int, - subsampling_factor: int = 4, - d_model: int = 256, - nhead: int = 4, - dim_feedforward: int = 2048, - num_encoder_layers: int = 12, - num_decoder_layers: int = 6, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - normalize_before: bool = True, - vgg_frontend: bool = False, - is_espnet_structure: bool = False, - mmi_loss: bool = True, - use_feat_batchnorm: bool = False, - discrete_bottleneck_pos: int = 8, - discrete_bottleneck_tot_classes: int = 512, - discrete_bottleneck_num_groups: int = 2 - ) -> None: - super(DiscreteBottleneckConformer, self).__init__( - num_features=num_features, - num_classes=num_classes, - subsampling_factor=subsampling_factor, - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - num_encoder_layers=num_encoder_layers, - num_decoder_layers=num_decoder_layers, - dropout=dropout, - normalize_before=normalize_before, - vgg_frontend=vgg_frontend, - mmi_loss=mmi_loss, - use_feat_batchnorm=use_feat_batchnorm, - ) - - self.encoder_pos = RelPositionalEncoding(d_model, dropout) - - encoder_layer = ConformerEncoderLayer( - d_model, - nhead, - dim_feedforward, - dropout, - cnn_module_kernel, - normalize_before, - is_espnet_structure, - ) - - discrete_bottleneck = DiscreteBottleneck(dim=d_model, - tot_classes=discrete_bottleneck_tot_classes, - num_groups=discrete_bottleneck_num_groups) - - self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers, - discrete_bottleneck=discrete_bottleneck, - discrete_bottleneck_pos=discrete_bottleneck_pos) - - - self.normalize_before = normalize_before - self.is_espnet_structure = is_espnet_structure - if self.normalize_before and self.is_espnet_structure: - self.after_norm = nn.LayerNorm(d_model) - else: - # Note: TorchScript detects that self.after_norm could be used inside forward() - # and throws an error without this change. - self.after_norm = identity - - def run_encoder( - self, x: Tensor, supervisions: Optional[Supervisions] = None - ) -> Tuple[Tensor, Optional[Tensor]]: - """ - Args: - x: - The model input. Its shape is [N, T, C]. - supervisions: - Supervision in lhotse format. 
-          See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
-          CAUTION: It contains length information, i.e., start and number of
-          frames, before subsampling
-          It is read directly from the batch, without any sorting. It is used
-          to compute encoder padding mask, which is used as memory key padding
-          mask for the decoder.
-
-        Returns:
-            Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
-            Tensor: Mask tensor of dimension (batch_size, input_length)
-        """
-        x = self.feat_embed(x)
-        x, pos_emb = self.encoder_pos(x)
-        x = x.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
-        mask = encoder_padding_mask(x.size(0), supervisions)
-        if mask is not None:
-            mask = mask.to(x.device)
-        x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, B, F)
-
-        return x, mask
-
 
 class ReverseGrad(torch.autograd.Function):
    def forward(ctx, x):
@@ -878,9 +769,9 @@ class DiscreteBottleneck(nn.Module):
        #         [ True,  True,  True,  True]])
        self.register_buffer('pred_cross_mask',
                             ((torch.arange(d) // c).unsqueeze(1) >= (torch.arange(d) // c).unsqueeze(0)))
-        self.reset_parameters()
+        self._reset_parameters()
 
-    def reset_parameters_(self):
+    def _reset_parameters(self):
        if hasattr(self, 'pred_cross'):
            torch.nn.init.kaiming_uniform_(self.pred_cross, a=math.sqrt(5))
 
@@ -967,7 +858,8 @@ class DiscreteBottleneck(nn.Module):
            This will be useful in computing a loss function that has a
            likelihood term with negative sign (i.e. the self-prediction).
            We'll later need to negate the gradient one more time
-            where we give the input to whatever module generated 'x'.
+            where we give the input to the prediction module that
+            generated 'x'.
 
        Returns a scalar Tensor representing the total probability.
        """
@@ -980,6 +872,8 @@ class DiscreteBottleneck(nn.Module):
 
        logprobs = self.pred_linear(x)
 
+        # Add "cross-terms" to logprobs; this is a regression that uses earlier
+        # groups to predict later groups.
        if self.num_groups > 1:
            pred_cross = self.pred_cross * self.pred_cross_mask
            t = self.tot_classes
@@ -999,11 +893,12 @@ class DiscreteBottleneck(nn.Module):
        logprobs = logprobs.reshape(S, N, self.num_groups, self.classes_per_group)
        # Normalize the log-probs (so the probabilities sum to one)
        logprobs = torch.nn.functional.log_softmax(logprobs, dim=-1)
+        logprobs = logprobs.reshape(S, N, C)
 
        if padding_mask is not None:
            assert padding_mask.dtype == torch.bool and padding_mask.shape == (N, S)
            padding_mask = torch.logical_not(padding_mask).transpose(0, 1).unsqueeze(-1)
-            # padding_mask.shape == (S, N, E)
+            assert padding_mask.shape == (S, N, 1)
            tot_prob = (logprobs * softmax * padding_mask).sum()
        else:
            tot_prob = (logprobs * softmax).sum()
@@ -1080,8 +975,8 @@ class ConformerEncoderLayer(nn.Module):
        self,
        src: Tensor,
        pos_emb: Tensor,
-        src_mask: Optional[Tensor] = None,
-        src_key_padding_mask: Optional[Tensor] = None,
+        attn_mask: Optional[Tensor] = None,
+        key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Pass the input through the encoder layer.
 
        Args:
            src: the sequence to the encoder layer (required).
            pos_emb: Positional embedding tensor (required).
-            src_mask: the mask for the src sequence (optional).
-            src_key_padding_mask: the mask for the src keys per batch (optional).
+            attn_mask: the mask for the src sequence (optional).
+            key_padding_mask: the mask for the src keys per batch (optional).
 
        Shape:
            src: (S, N, E).
            pos_emb: (N, 2*S-1, E)
-            src_mask: (S, S).
-            src_key_padding_mask: (N, S).
+            attn_mask: (S, S).
+                This probably won't be used; in fact it should not
+                be (it could in principle be used to enforce causal behavior,
+                but this conformer implementation does not support that).
+            key_padding_mask: (N, S).
            S is the source sequence length, N is the batch size, E is the feature number
        """
-
        # macaron style feed forward module
        residual = src
        src = self.norm_ff_macaron(src)
@@ -1115,8 +1011,8 @@ class ConformerEncoderLayer(nn.Module):
            src,
            src,
            pos_emb=pos_emb,
-            attn_mask=src_mask,
-            key_padding_mask=src_key_padding_mask,
+            attn_mask=attn_mask,
+            key_padding_mask=key_padding_mask,
        )[0]
        src = residual + self.dropout(src_att)
@@ -1141,7 +1037,6 @@ class ConformerEncoder(nn.Module):
 
    Args:
        encoder_layer: an instance of the ConformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
-        norm: the layer normalization component (optional).
 
    Examples::
        >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
@@ -1153,7 +1048,7 @@ class ConformerEncoder(nn.Module):
 
    def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
        super(ConformerEncoder, self).__init__()
-        self.layers = torch.nn.ModuleList([copy.deepcopy(encoder_leyer) for _ in range(num_layers)])
+        self.layers = torch.nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])
 
    def forward(
        self,
@@ -1183,7 +1078,6 @@ class ConformerEncoder(nn.Module):
 
        return x
 
-
 class RelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding module.
 
@@ -1313,7 +1207,6 @@ class RelPositionMultiheadAttention(nn.Module):
        nn.init.xavier_uniform_(self.in_proj.weight)
        nn.init.constant_(self.in_proj.bias, 0.0)
        nn.init.constant_(self.out_proj.bias, 0.0)
-
        nn.init.xavier_uniform_(self.pos_bias_u)
        nn.init.xavier_uniform_(self.pos_bias_v)
@@ -1663,6 +1556,7 @@ class RelPositionMultiheadAttention(nn.Module):
        )
 
        attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
+
        attn_output_weights = nn.functional.dropout(
            attn_output_weights, p=dropout_p, training=training
        )
@@ -1701,7 +1595,7 @@ class CausalConvolutionModule(nn.Module):
        self, channels: int, kernel_size: int, bias: bool = True
    ) -> None:
        """Construct a CausalConvolutionModule object."""
-        super(ConvolutionModule, self).__init__()
+        super(CausalConvolutionModule, self).__init__()
        # kernel_size should be an odd number for 'SAME' padding
        assert (kernel_size - 1) % 2 == 0
        self.kernel_size = kernel_size
@@ -1752,11 +1646,12 @@ class CausalConvolutionModule(nn.Module):
        x = nn.functional.glu(x, dim=1)  # (batch, channels, time)
 
        # 1D Depthwise Conv
-        (B, C, T) = x
+        (B, C, T) = x.shape
        padding = self.kernel_size - 1
        x = torch.cat((torch.zeros(B, C, padding, device=x.device, dtype=x.dtype),
                       x), dim=2)
-        x = self.depthwise_conv(x)   # <-- This has no padding.
+        x = self.depthwise_conv(x)  # <-- This convolution module does no padding,
+                                    # so we padded manually, on the left only.
        x = self.activation(self.norm(x))
 
@@ -1835,7 +1730,7 @@ class ConvolutionModule(nn.Module):
 
        x = self.pointwise_conv2(x)  # (batch, channel, time)
 
-        return x.permute(2, 0, 1)
+        return x.permute(2, 0, 1)  # (time, batch, channel)
 
 
 class Swish(torch.nn.Module):
@@ -1851,16 +1746,47 @@ def identity(x):
 
 
-def test_discrete_bottleneck_conformer():
+def _gen_rand_tokens(N: int) -> List[List[int]]:
+    ans = []
+    for _ in range(N):
+        S = random.randint(1, 20)
+        ans.append([random.randint(3, 30) for _ in range(S)])
+    return ans
+
+def _gen_supervision(tokens: List[List[int]]):
+    ans = dict()
+    N = len(tokens)
+    ans['sequence_idx'] = torch.arange(N, dtype=torch.int32)
+    ans['start_frame'] = torch.zeros(N, dtype=torch.int32)
+    ans['num_frames'] = torch.tensor([random.randint(20, 35) for _ in tokens])
+    return ans
+
+def _test_bidirectional_conformer():
    num_features = 40
    num_classes = 1000
-    m = DiscreteBottleneckConformer(num_features, num_classes)
+    m = BidirectionalConformer(num_features, num_classes)
    T = 35
    N = 10
    C = num_features
    feats = torch.randn(N, T, C)
-    ctc_output, _, _ = m(feats)
-    # [N, T, C].
+
+    tokens = _gen_rand_tokens(N)
+    supervision = _gen_supervision(tokens)
+    print("tokens = ", tokens)
+    print("supervision = ", supervision)
+    # memory: [T, N, C]
+    (memory, bn_memory, pos_emb, sampled, softmax, key_padding_mask) = m(feats, supervision)
+
+    # ctc_output: [N, T, C].
+    ctc_output = m.ctc_encoder_forward(memory, pos_emb, key_padding_mask)
+
+    decoder_loss = m.decoder_forward(memory, key_padding_mask, tokens,
+                                     sos_id=1,
+                                     eos_id=2)
+
+    (T, N, E) = memory.shape
+    memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1,:,:]), dim=0)
 
 
 if __name__ == '__main__':
-    test_discrete_bottleneck_conformer()
+    _test_bidirectional_conformer()
diff --git a/egs/librispeech/ASR/conformer_ctc_bn_2d/transformer.py b/egs/librispeech/ASR/conformer_ctc_bn_2d/transformer.py
index 5d5213f56..0c727ec74 100644
--- a/egs/librispeech/ASR/conformer_ctc_bn_2d/transformer.py
+++ b/egs/librispeech/ASR/conformer_ctc_bn_2d/transformer.py
@@ -914,6 +914,9 @@ def encoder_padding_mask(
    ).unsqueeze(-1)
 
    mask = seq_range_expand >= seq_length_expand
+    # Assert that in each row, i.e. each utterance, at least one frame is not
+    # masked.  Otherwise it may lead to NaNs appearing in the attention computation.
+    assert torch.all(torch.sum(torch.logical_not(mask), dim=1) != 0)
 
    return mask
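
A note on the double gradient reversal used by self_predictor_forward() and compute_prob() above: only a fragment of ReverseGrad is visible in this patch.  A minimal sketch of a gradient-reversal function, assuming the standard torch.autograd.Function interface (the actual class body is not shown here, so this is an assumption):

    import torch

    class ReverseGrad(torch.autograd.Function):
        # Identity in the forward pass; negate the gradient in the backward pass.
        @staticmethod
        def forward(ctx, x: torch.Tensor) -> torch.Tensor:
            return x

        @staticmethod
        def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
            return -grad_output

    x = torch.randn(3, requires_grad=True)
    ReverseGrad.apply(x).sum().backward()
    print(x.grad)  # tensor([-1., -1., -1.]): the sum's gradient, negated

Reversing once on memory_shifted and once inside compute_prob() means the self-predictor's own parameters receive the correct (prediction) gradient, while the encoder that produced the input is pushed in the opposite direction, matching the mutual-information objective described in the docstrings above.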
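
The causality argument for CausalConvolutionModule (and for the "no mask is needed" comment in self_predictor_forward()) can be checked in isolation: left-padding the input by kernel_size - 1 and running an unpadded depthwise Conv1d keeps the output length equal to the input length, and makes output frame t depend only on input frames <= t.  A small sketch with hypothetical sizes:

    import torch
    import torch.nn as nn

    B, C, T, kernel_size = 2, 8, 50, 31
    # Depthwise conv with no internal padding, as in the module above.
    conv = nn.Conv1d(C, C, kernel_size, padding=0, groups=C, bias=False)

    x = torch.randn(B, C, T)
    pad = kernel_size - 1
    y = conv(torch.cat((torch.zeros(B, C, pad), x), dim=2))
    assert y.shape == (B, C, T)

    # Perturb the input from frame 25 onwards; outputs before frame 25
    # must be unaffected if the convolution is causal.
    x2 = x.clone()
    x2[:, :, 25:] += 1.0
    y2 = conv(torch.cat((torch.zeros(B, C, pad), x2), dim=2))
    assert torch.equal(y[:, :, :25], y2[:, :, :25])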
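
The assertion added to encoder_padding_mask() in transformer.py guards against an utterance whose frames are all masked after subsampling.  The failure mode it prevents is easy to reproduce: a softmax over a row that is entirely -inf yields NaN, which would then propagate through the attention output.  A minimal demonstration:

    import torch

    scores = torch.randn(1, 5)
    all_masked = torch.ones(1, 5, dtype=torch.bool)
    scores = scores.masked_fill(all_masked, float("-inf"))
    print(torch.softmax(scores, dim=-1))  # tensor([[nan, nan, nan, nan, nan]])

This is the situation that would arise in the attention computation if key_padding_mask masked every frame of some utterance, hence the assert that every row of the mask leaves at least one frame unmasked.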