diff --git a/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py b/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
index 566cad8cf..249997a39 100644
--- a/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
+++ b/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import copy
 import math
 import warnings
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
@@ -23,7 +23,638 @@ import torch_flow_sampling
 import torch
 from torch import Tensor, nn
 
-from transformer import Supervisions, Transformer, encoder_padding_mask
+from transformer import Supervisions, Transformer, TransformerEncoderLayer, \
+    TransformerDecoderLayer, encoder_padding_mask, decoder_padding_mask, \
+    LabelSmoothingLoss, PositionalEncoding, VggSubsampling, pad_sequence, \
+    add_sos, add_eos, generate_square_subsequent_mask
+
+
+class ConformerTrunk(nn.Module):
+    def __init__(self,
+                 num_features: int,
+                 subsampling_factor: int = 4,
+                 d_model: int = 256,
+                 nhead: int = 4,
+                 dim_feedforward: int = 2048,
+                 num_layers: int = 10,
+                 dropout: float = 0.1,
+                 cnn_module_kernel: int = 31,
+                 use_feat_batchnorm: bool = True) -> None:
+        super(ConformerTrunk, self).__init__()
+        if use_feat_batchnorm:
+            self.feat_batchnorm = nn.BatchNorm1d(num_features)
+
+        self.num_features = num_features
+        self.subsampling_factor = subsampling_factor
+        if subsampling_factor != 4:
+            raise NotImplementedError("Support only 'subsampling_factor=4'.")
+
+        # self.feat_embed converts the input of shape [N, T, num_features]
+        # to the shape [N, T//subsampling_factor, d_model].
+        # That is, it does two things simultaneously:
+        #   (1) subsampling: T -> T//subsampling_factor
+        #   (2) embedding: num_features -> d_model
+        self.feat_embed = VggSubsampling(num_features, d_model)
+
+        self.encoder_pos = RelPositionalEncoding(d_model, dropout)
+
+        encoder_layer = ConformerEncoderLayer(
+            d_model,
+            nhead,
+            dim_feedforward,
+            dropout,
+            cnn_module_kernel,
+        )
+
+        self.encoder = ConformerEncoder(encoder_layer, num_layers)
+
+    def forward(
+        self, x: torch.Tensor, supervisions: Optional[Supervisions] = None
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """
+        Args:
+          x:
+            The input tensor.  Its shape is [N, T, C].
+          supervisions:
+            Supervisions in lhotse format.
+            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
+            (CAUTION: It contains length information, i.e., start and number of
+            frames, before subsampling)
+
+        Returns:
+          Return a tuple containing 2 tensors:
+            - Encoder output with shape [T, N, C].  It can be used as key and
+              value for the decoder.
+            - Encoder output padding mask.  It can be used as
+              memory_key_padding_mask for the decoder.  Its shape is [N, T].
+              It is None if `supervisions` is None.
+        """
+        if hasattr(self, 'feat_batchnorm'):
+            x = x.permute(0, 2, 1)  # [N, T, C] -> [N, C, T]
+            x = self.feat_batchnorm(x)
+            x = x.permute(0, 2, 1)  # [N, C, T] -> [N, T, C]
+
+        x = self.feat_embed(x)
+        x, pos_emb = self.encoder_pos(x)
+        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+        mask = encoder_padding_mask(x.size(0), supervisions)
+        mask = mask.to(x.device) if mask is not None else None
+        x = self.encoder(x, pos_emb, key_padding_mask=mask)  # (T, N, C)
+
+        return x, mask
+
+
+class BidirectionalConformer(nn.Module):
+    """
+    This is a modified conformer where the encoder outputs the probabilities of
+    hidden discrete symbols.  These probabilities are sampled from and then
+    given as input to two separate "forward" decoders: the CTC decoder and the
+    attention decoder.
+    We also have a reverse attention decoder, where we predict these discrete
+    symbols from the word/phone/word-piece sequences.  From that we subtract
+    the log-probs of the symbols given a simple "self-predictor" that predicts
+    those symbols from the previous symbols; this prevents the symbols from
+    collapsing onto the most likely ones as a result of the reverse attention
+    decoder's contribution to the loss function.
+
+    Caution: this code has several different 'forward' functions: the
+    regular 'forward' function is for the encoder, but there are others:
+
+      forward()                  [generates sampled symbols from the conformer
+                                  encoder, also returning some additional
+                                  things, e.g. the pre-sampling symbol
+                                  probabilities]
+      decoder_forward()          [predicts word-piece symbols from hidden
+                                  symbols; returns a scalar total log-like]
+      ctc_encoder_forward()      [caution: this just gives the loglikes,
+                                  it does not do CTC decoding]
+      reverse_decoder_forward()  [predicts hidden symbols from word-piece
+                                  symbols]
+      self_prediction_forward()  [predicts hidden symbols from previous hidden
+                                  symbols using a simple model; returns a scalar
+                                  total log-like.  We subtract this from the
+                                  loss function as a mechanism to avoid
+                                  "trivial" solutions; this can also be
+                                  justified from the point of view of
+                                  maximizing mutual information]
+
+    Args:
+        num_features: Input acoustic feature dimension, e.g. 40.
+        num_classes: Output dimension, which might be the number of phones or
+            word-piece symbols, including blank/sos/eos.
+        subsampling_factor: Factor by which the acoustic features are
+            downsampled in the encoder.
+        d_model: Dimension for attention computations
+        nhead: Number of heads in attention computations
+        dim_feedforward: Dimension for feedforward computations in
+            the conformer
+        num_trunk_encoder_layers: Number of encoder layers in the "trunk" that
+            encodes the acoustic features
+        num_ctc_encoder_layers: Number of layers in the CTC encoder
+            that comes after the trunk (and after the discrete
+            bottleneck, if bypass_bottleneck == False).
+            These are just conformer encoder layers.
+        num_decoder_layers: Number of layers in the attention decoder;
+            this goes from the trunk to the word-pieces or phones.
+        num_reverse_encoder_layers: Number of layers in the reverse
+            encoder, which encodes the word-pieces or phones
+            and whose output will be used to predict the
+            discrete bottleneck.
+        num_reverse_decoder_layers: Number of layers in the reverse
+            decoder, which predicts the discrete-bottleneck
+            samples from the word sequence.
+        num_self_predictor_layers: Number of layers in the simple
+            self-predictor model that predicts the discrete-bottleneck
+            samples from its own previous frames.  This is
+            intended to be a relatively simple model because its main
+            useful function is to prevent "trivial" solutions,
+            such as collapse of the distribution to a single symbol,
+            or symbols that are highly correlated across time.
+        bypass_bottleneck: If true, bypass the discrete bottleneck
+            when predicting the CTC output and the decoder
+            that decodes the word-pieces or phones.
+        dropout: Dropout probability
+        cnn_module_kernel: Kernel size in the forward conformer layers
+        is_bpe: If false, we'll add one class (for SOS/EOS) to the number of
+            classes at the output of the decoder
+        use_feat_batchnorm: If true, apply batchnorm to the input features.
+        discrete_bottleneck_tot_classes: Total number of classes
+            (across all groups) in the discrete bottleneck
+        discrete_bottleneck_num_groups: Number of groups of classes/symbols
+            in the discrete bottleneck
+    """
+    def __init__(
+        self,
+        num_features: int,
+        num_classes: int,
+        subsampling_factor: int = 4,
+        d_model: int = 256,
+        nhead: int = 4,
+        dim_feedforward: int = 2048,
+        num_trunk_encoder_layers: int = 12,
+        num_ctc_encoder_layers: int = 4,
+        num_decoder_layers: int = 6,
+        num_reverse_encoder_layers: int = 4,
+        num_reverse_decoder_layers: int = 4,
+        num_self_predictor_layers: int = 3,
+        bypass_bottleneck: bool = True,
+        dropout: float = 0.1,
+        cnn_module_kernel: int = 31,
+        is_bpe: bool = False,
+        use_feat_batchnorm: bool = True,
+        discrete_bottleneck_tot_classes: int = 512,
+        discrete_bottleneck_num_groups: int = 4
+    ) -> None:
+        super(BidirectionalConformer, self).__init__()
+
+        self.bypass_bottleneck = bypass_bottleneck
+
+        self.trunk = ConformerTrunk(num_features, subsampling_factor,
+                                    d_model, nhead, dim_feedforward,
+                                    num_trunk_encoder_layers, dropout,
+                                    cnn_module_kernel,
+                                    use_feat_batchnorm)
+
+        self.num_features = num_features
+        self.num_classes = num_classes
+        self.subsampling_factor = subsampling_factor
+
+        encoder_layer = ConformerEncoderLayer(
+            d_model,
+            nhead,
+            dim_feedforward,
+            dropout,
+            cnn_module_kernel,
+        )
+        self.ctc_encoder = ConformerEncoder(encoder_layer, num_ctc_encoder_layers)
+        self.ctc_output_layer = nn.Sequential(
+            nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
+        )
+
+        # absolute position encoding, used by various layer types
+        self.abs_pos = PositionalEncoding(d_model, dropout)
+
+        if num_decoder_layers > 0:
+            # extra class for the sos/eos symbol, if not BPE
+            self.decoder_num_class = self.num_classes if is_bpe else self.num_classes + 1
+
+            # self.token_embed is the token embedding (embedding for phones or
+            # word-pieces) that is used for both the forward and reverse decoders
+            self.token_embed_scale = d_model ** 0.5
+            self.token_embed = nn.Embedding(
+                num_embeddings=self.decoder_num_class, embedding_dim=d_model,
+                _weight=torch.randn(self.decoder_num_class, d_model) * (1 / self.token_embed_scale)
+            )
+
+            decoder_layer = TransformerDecoderLayer(
+                d_model=d_model,
+                nhead=nhead,
+                dim_feedforward=dim_feedforward,
+                dropout=dropout,
+            )
+
+            self.decoder = nn.TransformerDecoder(
+                decoder_layer=decoder_layer,
+                num_layers=num_decoder_layers,
+                norm=nn.LayerNorm(d_model)
+            )
+
+            self.decoder_output_layer = torch.nn.Linear(
+                d_model, self.decoder_num_class
+            )
+
+            # Caution: LabelSmoothingLoss uses padding_idx=-1 as a default;
+            # that is the target value it will ignore.
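+            # (This matches the padding_value of -1 used for ys_out_pad in
+            # decoder_forward() below, so padded positions are excluded from
+            # the label-smoothing loss.)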
+            self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class)
+        else:
+            self.decoder_criterion = None
+
+        if num_reverse_encoder_layers > 0:
+            self.reverse_encoder_pos = PositionalEncoding(d_model, dropout)
+
+            encoder_layer = TransformerEncoderLayer(
+                d_model=d_model,
+                nhead=nhead,
+                dim_feedforward=dim_feedforward,
+                dropout=dropout,
+            )
+
+            self.reverse_encoder = nn.TransformerEncoder(
+                encoder_layer=encoder_layer,
+                num_layers=num_reverse_encoder_layers,
+                norm=nn.LayerNorm(d_model))
+
+        if num_reverse_decoder_layers > 0:
+            self.reverse_decoder_pos = PositionalEncoding(d_model, dropout)
+
+            decoder_layer = TransformerDecoderLayer(
+                d_model=d_model,
+                nhead=nhead,
+                dim_feedforward=dim_feedforward,
+                dropout=dropout,
+            )
+
+            self.reverse_decoder = nn.TransformerDecoder(
+                decoder_layer=decoder_layer,
+                num_layers=num_reverse_decoder_layers,
+                norm=nn.LayerNorm(d_model)
+            )
+            # There is no "linear output" for the reverse decoder;
+            # that is handled by the discrete_bottleneck layer itself.
+            # It just accepts the output of self.reverse_decoder as
+            # the input to its prediction mechanism.
+
+        if num_self_predictor_layers > 0:
+            encoder_layer = SimpleCausalEncoderLayer(d_model,
+                                                     dropout=dropout)
+            self.self_predictor_encoder = simple_causal_encoder(
+                encoder_layer, num_self_predictor_layers)
+
+        self.discrete_bottleneck = DiscreteBottleneck(
+            dim=d_model,
+            tot_classes=discrete_bottleneck_tot_classes,
+            num_groups=discrete_bottleneck_num_groups)
+
+    def forward(self, x: Tensor, supervision: Optional[Supervisions],
+                need_softmax: bool = True
+                ) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
+        """
+        Forward function that "encodes" the features.
+
+        Args:
+          x:
+            The input tensor.  Its shape is [N, T, F], i.e.
+            [batch_size, num_frames, num_features].
+          supervision:
+            Supervision in lhotse format (optional).
+            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
+            (CAUTION: It contains length information, i.e., start and number of
+            frames, before subsampling).  Used only to compute masking information.
+          need_softmax:
+            If true, the last output ("softmax") will be computed.  This can be useful
+            in the reverse model, but it is only necessary if straight_through_scale != 1.0.
+
+        Returns: (memory, bn_memory, sampled, softmax, key_padding_mask), where:
+
+          `memory` is a Tensor of shape [T, N, E], i.e. [T, batch_size, embedding_dim],
+             where T is actually a subsampled form of the num_frames of the input `x`.
+             If self.bypass_bottleneck, it will be taken from before the discrete
+             bottleneck; otherwise, from after it.
+          `bn_memory` has the same shape as `memory`, but comes after the discrete bottleneck
+             regardless of the value of self.bypass_bottleneck.
+          `sampled` is a Tensor of shape [T, N, C] where C corresponds to
+             `discrete_bottleneck_tot_classes` as given to the constructor.  This will be
+             needed for the 'reverse' model.
+          `softmax` is a "soft" version of `sampled`.  It will only be returned if
+             need_softmax == True; else it will be None.
+          `key_padding_mask` is the padding mask from the trunk, of shape [N, T];
+             it is None if `supervision` is None.
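+
+        Example (a rough usage sketch; shapes are illustrative and assume the
+        constructor defaults, where the time axis is subsampled by 4,
+        d_model == 256 and discrete_bottleneck_tot_classes == 512)::
+
+            >>> model = BidirectionalConformer(num_features=80, num_classes=500)
+            >>> feats = torch.randn(8, 1000, 80)   # (N, T, F)
+            >>> memory, bn_memory, sampled, softmax, mask = model(feats, None)
+            >>> # memory, bn_memory: roughly (T//4, 8, 256); sampled, softmax:
+            >>> # roughly (T//4, 8, 512); mask is None since no supervision was given.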
+        """
+        encoder_output, memory_key_padding_mask = self.trunk(x, supervision)
+
+        bn_memory, sampled, softmax = self.discrete_bottleneck(encoder_output,
+                                                               need_softmax=need_softmax)
+
+        memory = encoder_output if self.bypass_bottleneck else bn_memory
+
+        return (memory, bn_memory, sampled, softmax, memory_key_padding_mask)
+
+    def decoder_forward(
+        self,
+        memory: torch.Tensor,
+        memory_key_padding_mask: torch.Tensor,
+        token_ids: List[List[int]],
+        sos_id: int,
+        eos_id: int,
+    ) -> torch.Tensor:
+        """
+        Args:
+          memory:
+            It's the first output of forward(), with shape [T, N, E]
+          memory_key_padding_mask:
+            The padding mask from forward()
+          token_ids:
+            A list-of-lists of IDs.  Each sublist contains IDs for an utterance.
+            The IDs can be either phone IDs or word-piece IDs.
+          sos_id:
+            sos token id
+          eos_id:
+            eos token id
+
+        Returns:
+            A scalar, the **sum** of the label smoothing loss over utterances
+            in the batch without any normalization.
+        """
+        ys_in = add_sos(token_ids, sos_id=sos_id)
+        ys_in = [torch.tensor(y) for y in ys_in]
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
+
+        ys_out = add_eos(token_ids, eos_id=eos_id)
+        ys_out = [torch.tensor(y) for y in ys_out]
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
+
+        device = memory.device
+        ys_in_pad = ys_in_pad.to(device)
+        ys_out_pad = ys_out_pad.to(device)
+
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
+            device
+        )
+
+        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
+        # TODO: Use length information to create the decoder padding mask.
+        # We set the first column to False since the first column in ys_in_pad
+        # contains sos_id, which is the same as eos_id in our current setting.
+        tgt_key_padding_mask[:, 0] = False
+
+        tgt = self.token_embed(ys_in_pad) * self.token_embed_scale  # (N, T) -> (N, T, C)
+        tgt = self.abs_pos(tgt)
+        tgt = tgt.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+
+        # We don't supply tgt_key_padding_mask here because it's unnecessary:
+        # thanks to tgt_mask, those positions are already excluded from
+        # consideration for any output position that we're going to care about.
+        pred_pad = self.decoder(
+            tgt=tgt,
+            memory=memory,
+            tgt_mask=tgt_mask,
+            memory_key_padding_mask=memory_key_padding_mask,
+        )  # (T, N, C)
+        pred_pad = pred_pad.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        pred_pad = self.decoder_output_layer(pred_pad)  # (N, T, C)
+
+        decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad)
+
+        return decoder_loss
+
+    def ctc_encoder_forward(
+        self,
+        memory: torch.Tensor,
+        memory_key_padding_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Passes the output of forward() through the CTC encoder and the CTC
+        output layer, giving the output that can be fed to the CTC loss function.
+
+        Args:
+          memory:
+            It's the output of forward(), with shape [T, N, E]
+          memory_key_padding_mask:
+            The padding mask from forward()
+
+        Returns:
+            A Tensor with shape [N, T, C] where C is the number of classes
+            (e.g. the number of phones or word pieces).  It contains normalized
+            log-probabilities.
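+
+            For example, the result could be fed to CTC loss roughly as follows
+            (a sketch only; `targets`, `input_lengths` and `target_lengths` are
+            assumed to be prepared by the caller)::
+
+                >>> log_probs = model.ctc_encoder_forward(memory, mask)  # (N, T, C)
+                >>> loss = torch.nn.functional.ctc_loss(
+                ...     log_probs.permute(1, 0, 2),  # ctc_loss expects (T, N, C)
+                ...     targets, input_lengths, target_lengths)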
+        """
+        x = self.ctc_encoder(memory,
+                             key_padding_mask=memory_key_padding_mask)
+        x = self.ctc_output_layer(x)
+        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        x = nn.functional.log_softmax(x, dim=-1)  # (N, T, C)
+        return x
+
+    def self_prediction_forward(
+        self,
+        memory_shifted: torch.Tensor,
+        memory_key_padding_mask: torch.Tensor,
+        sampled: torch.Tensor,
+        softmax: Optional[torch.Tensor],
+        reverse_gradient: bool = True) -> Tensor:
+        """
+        Returns the total log-prob of the labels sampled in the discrete
+        bottleneck layer, as predicted using a relatively simple model from
+        previous frames sampled from the bottleneck layer.
+        [This appears in the denominator of an expression for mutual information.]
+
+        Args:
+          memory_shifted:
+            It's the output of forward(), with shape [T, N, E], shifted
+            by one so that memory_shifted[t] == memory[t-1], as in:
+                (T, N, E) = memory.shape
+                memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1,:,:]), dim=0)
+          memory_key_padding_mask:
+            The padding mask from the encoder.
+          sampled: a Tensor of shape [T, N, C] where C corresponds to
+            `discrete_bottleneck_tot_classes` as given to the constructor.
+          softmax: a "soft" version of `sampled`; if None, it will default to `sampled`.
+          reverse_gradient: will likely be true.  If true, the gradient is reversed twice
+            in this computation, so that we train the predictor with the correct
+            gradient, i.e. to predict, not anti-predict (since the return value
+            of this function will appear with positive, not negative, sign in the
+            loss function, and so will be minimized).
+        Returns:
+            A scalar tensor: the total log-prob of the sampled symbols, summed
+            over the batch without any normalization.
+        """
+
+        if reverse_gradient:
+            # Reversing the gradient for memory_shifted puts the gradient back into
+            # the correct sign; we reversed it in
+            # self.discrete_bottleneck.compute_prob(), in order that
+            # self.self_predictor_encoder can learn to predict.
+            # (You have to read the code in reverse to reason about
+            # what happens to the gradients).
+            memory_shifted = ReverseGrad.apply(memory_shifted)
+
+        predictor = self.self_predictor_encoder(memory_shifted)
+
+        prob = self.discrete_bottleneck.compute_prob(predictor,
+                                                     sampled, softmax,
+                                                     memory_key_padding_mask,
+                                                     reverse_gradient=True)
+        return prob
+
+    def reverse_decoder_forward(
+        self,
+        memory_shifted: torch.Tensor,
+        memory_key_padding_mask: torch.Tensor,
+        sampled: torch.Tensor,
+        softmax: Optional[torch.Tensor],
+        token_ids: List[List[int]],
+        sos_id: int,
+        eos_id: int,
+        padding_id: int,
+    ) -> torch.Tensor:
+        """
+        This is the reverse decoder function, which returns the total log-prob of the
+        labels sampled in the discrete bottleneck layer, as predicted from the
+        supervision word-sequence.
+
+        Args:
+          memory_shifted:
+            It's the output of forward(), with shape [T, N, E], shifted
+            by one so that memory_shifted[t] == memory[t-1], as in:
+                (T, N, E) = memory.shape
+                memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1,:,:]), dim=0)
+          memory_key_padding_mask:
+            The padding mask from the encoder.
+          sampled: a Tensor of shape [T, N, C] where C corresponds to
+            `discrete_bottleneck_tot_classes` as given to the constructor.
+          softmax: a "soft" version of `sampled`; if None, it will default to `sampled`.
+          token_ids:
+            A list-of-lists of IDs.  Each sublist contains IDs for an utterance.
+            The IDs can be either phone IDs or word-piece IDs.
+          sos_id:
+            sos token id
+          eos_id:
+            eos token id
+          padding_id:
+            token id used for padding of the `token_ids` when they appear as the
+            input, e.g. the blank id or eos_id.
+        Returns:
+            A scalar: the total log-prob of the sampled bottleneck symbols,
+            summed over the batch without any normalization.
+        """
+
+        # Add both sos and eos symbols to token_ids.  These will be used
+        # as an input, so there is no harm in adding both of them.
+        token_ids = [torch.tensor([sos_id] + utt + [eos_id]) for utt in token_ids]
+
+        tokens_padded = pad_sequence(token_ids, batch_first=True,
+                                     padding_value=padding_id).to(memory_shifted.device)
+
+        tokens_key_padding_mask = decoder_padding_mask(tokens_padded,
+                                                       ignore_id=padding_id)
+
+        T = memory_shifted.shape[0]
+        # the targets, here, are the hidden discrete symbols we are predicting
+        tgt_mask = generate_square_subsequent_mask(T, device=memory_shifted.device)
+
+        token_embedding = self.token_embed(tokens_padded) * self.token_embed_scale
+        token_embedding = self.reverse_encoder_pos(token_embedding)
+        token_embedding = token_embedding.permute(1, 0, 2)  # (N, S, C) -> (S, N, C)
+        token_memory = self.reverse_encoder(token_embedding,
+                                            src_key_padding_mask=tokens_key_padding_mask)
+        # token_memory is of shape (S, N, C), where S is the length of the token sequence.
+
+        hidden_predictor = self.reverse_decoder(
+            tgt=memory_shifted,
+            memory=token_memory,
+            tgt_mask=tgt_mask,
+            memory_key_padding_mask=tokens_key_padding_mask)
+
+        # The discrete bottleneck handles the output projection and the
+        # log-prob computation itself; padded frames are masked out.
+        total_prob = self.discrete_bottleneck.compute_prob(
+            hidden_predictor,
+            sampled, softmax,
+            memory_key_padding_mask)
+
+        return total_prob
+
+
+class SimpleCausalEncoderLayer(nn.Module):
+    """
+    This is a simple encoder layer that only sees left-context; it is
+    based on the ConformerEncoderLayer, with the attention and one of
+    the feed-forward components stripped out.
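+
+    Example (illustrative only)::
+
+        >>> layer = SimpleCausalEncoderLayer(d_model=256)
+        >>> x = torch.randn(100, 8, 256)   # (T, N, C)
+        >>> y = layer(x)                   # same shape; y[t] depends only on x[:t+1]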
+    """
+    def __init__(self,
+                 d_model: int,
+                 dim_feedforward: int = 512,
+                 dropout: float = 0.1,
+                 cnn_module_kernel: int = 15):
+        super(SimpleCausalEncoderLayer, self).__init__()
+
+        self.feed_forward = nn.Sequential(
+            nn.Linear(d_model, dim_feedforward),
+            Swish(),
+            nn.Dropout(dropout),
+            nn.Linear(dim_feedforward, d_model),
+        )
+        self.conv_module = CausalConvolutionModule(d_model,
+                                                   cnn_module_kernel)
+
+        self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
+        self.ff_scale = 0.5
+        self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
+        self.norm_final = nn.LayerNorm(
+            d_model
+        )  # for the final output of the block
+
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, src: Tensor) -> Tensor:
+        # convolution module
+        residual = src
+        src = self.norm_conv(src)
+        src = residual + self.dropout(self.conv_module(src))
+
+        # feed forward module
+        residual = src
+        src = self.norm_ff(src)
+        src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
+
+        # final normalization
+        src = self.norm_final(src)
+        return src
+
+
+# for search: SimpleCausalEncoder
+def simple_causal_encoder(encoder_layer: nn.Module,
+                          num_layers: int) -> nn.Sequential:
+    return nn.Sequential(*[copy.deepcopy(encoder_layer) for _ in range(num_layers)])
+
 
 class DiscreteBottleneckConformer(Transformer):
@@ -39,8 +670,6 @@ class DiscreteBottleneckConformer(Transformer):
         num_decoder_layers (int): number of decoder layers
         dropout (float): dropout rate
         cnn_module_kernel (int): Kernel size of convolution module
-        normalize_before (bool): whether to use layer_norm before the first block.
-        vgg_frontend (bool): whether to use vgg frontend.
         discrete_bottleneck_pos (int): position in the encoder at which
            to place the discrete bottleneck (this many encoder layers
            will precede it)
@@ -134,7 +763,7 @@ class DiscreteBottleneckConformer(Transformer):
             Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
             Tensor: Mask tensor of dimension (batch_size, input_length)
         """
-        x = self.encoder_embed(x)
+        x = self.feat_embed(x)
         x, pos_emb = self.encoder_pos(x)
         x = x.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
         mask = encoder_padding_mask(x.size(0), supervisions)
@@ -142,17 +771,21 @@ class DiscreteBottleneckConformer(Transformer):
         mask = mask.to(x.device)
         x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, B, F)
 
-        if self.normalize_before and self.is_espnet_structure:
-            x = self.after_norm(x)
-
         return x, mask
 
 
+class ReverseGrad(torch.autograd.Function):
+    """Identity in the forward pass; negates the gradient in the backward pass."""
+
+    @staticmethod
+    def forward(ctx, x):
+        return x
+
+    @staticmethod
+    def backward(ctx, x_grad):
+        return -x_grad
+
+
 class DiscreteBottleneck(nn.Module):
     """
     This layer forces its input through an information bottleneck via
-    a discretization operation with sampling. We use the torch-flow-sampling
+    a discretization operation with sampling, and allows you to
+    predict the likelihood of those discretized values.
+    We use the torch-flow-sampling
     package for this, to provide a differentiable softmax that should be
     much better than Gumbel in terms of actually giving an information
     bottleneck.
@@ -183,6 +816,9 @@ class DiscreteBottleneck(nn.Module):
       min_prob_ratio:  For any class whose average softmax output, for a given
              minibatch, is less than min_prob_ratio times
+      include_predictor: If true, include the parameters
+            necessary to predict the likelihoods of the
+            classes from some kind of input embedding.
     """
     def __init__(
         self,
@@ -191,7 +827,8 @@ class DiscreteBottleneck(nn.Module):
        num_groups: int,
        interp_prob: float = 1.0,
        straight_through_scale: float = 0.333,
-       min_prob_ratio: float = 0.1
+       min_prob_ratio: float = 0.1,
+       include_predictor: bool = True
     ):
         super(DiscreteBottleneck, self).__init__()
         self.norm_in = nn.LayerNorm(dim)
@@ -216,24 +853,73 @@ class DiscreteBottleneck(nn.Module):
         # (c.f. 'min_prob_ratio').
         self.register_buffer('class_offsets', torch.zeros(tot_classes))
 
-        self.linear2 = nn.Linear(tot_classes, dim)
+        self.linear2 = nn.Linear(tot_classes, dim, bias=False)
         self.norm_out = nn.LayerNorm(dim)
 
+        if include_predictor:
+            # pred_linear predicts the class probabilities from a predictor
+            # embedding.
+            self.pred_linear = nn.Linear(dim, tot_classes)
 
-    def forward(self, x: Tensor) -> Tensor:
+            if self.num_groups > 1:
+                # We predict the logprobs of each group from the outputs of the
+                # previous groups.  This is done via a masked multiply, where
+                # the masking operates on blocks.  This projects from [all but
+                # the last group] to [all but the first group], so the diagonal
+                # of the mask can be 1, not 0, saving compute.
+                d = tot_classes - self.classes_per_group
+                c = self.classes_per_group
+                self.pred_cross = nn.Parameter(torch.Tensor(d, d))
+                # If d == 4 and c == 2, the expression below has the following value
+                # (treat True as 1 and False as 0):
+                #   tensor([[ True,  True, False, False],
+                #           [ True,  True, False, False],
+                #           [ True,  True,  True,  True],
+                #           [ True,  True,  True,  True]])
+                self.register_buffer(
+                    'pred_cross_mask',
+                    ((torch.arange(d) // c).unsqueeze(1) >= (torch.arange(d) // c).unsqueeze(0)))
+            self.reset_parameters()
+
+    def reset_parameters(self):
+        if hasattr(self, 'pred_cross'):
+            torch.nn.init.kaiming_uniform_(self.pred_cross, a=math.sqrt(5))
+
+    def forward(self, x: Tensor, need_softmax: bool = False) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
         """
         Forward computation.
         Args:
          x: The input tensor, of shape (S, N, E) where S is the sequence length,
             N is the batch size and E is the embedding dim.
+        Returns (embedding, sampled, softmax), where:
+
+          embedding: of shape (S, N, E) where E is the embedding dimension (the `dim` arg
+            to the constructor); this is the output embedding, projected from the
+            sampled class probabilities.
+          sampled: of shape (S, N, C) where C is the `tot_classes` arg to the
+            constructor; these are the sampled one-hot vectors, or interpolations
+            thereof.  They will be needed if we try to predict the discrete values
+            (e.g. in some kind of reverse model).
+          softmax: a Tensor of shape (S, N, C) if need_softmax is True; else, None.
+            This is the non-sampled softmax output.  We use this as the target when
+            evaluating the 'reverse' model (predicting the probabilities of these
+            classes), as we can treat it as an expectation of the result of
+            sampling -> lower-variance derivatives.
+            This is unnecessary if straight_through_scale == 1.0, since in that
+            case it would not affect the backpropagated derivatives.
         """
         x = self.norm_in(x)
         x = self.linear1(x)
         x = x + self.class_offsets
 
         (S, N, tot_classes) = x.shape
+        x = x.reshape(S, N, self.num_groups, self.classes_per_group)
+        # This is a little wasteful, since we already compute the softmax
+        # inside 'flow_sample'.
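+
+      Example (illustrative; requires the torch-flow-sampling package)::
+
+          >>> bn = DiscreteBottleneck(dim=256, tot_classes=512, num_groups=4)
+          >>> x = torch.randn(100, 8, 256)                  # (S, N, E)
+          >>> embedding, sampled, softmax = bn(x, need_softmax=True)
+          >>> # embedding: (100, 8, 256); sampled, softmax: (100, 8, 512),
+          >>> # i.e. 4 groups of 128 classes per frame.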
+        softmax = x.softmax(dim=-1).reshape(S, N, tot_classes) if need_softmax else None
+
         x = torch_flow_sampling.flow_sample(x,
                                             interp_prob=self.interp_prob,
                                             straight_through_scale=self.straight_through_scale)
@@ -241,6 +927,8 @@ class DiscreteBottleneck(nn.Module):
         assert x.shape == (S, N, self.num_groups, self.classes_per_group)
         x = x.reshape(S, N, tot_classes)
 
+        sampled = x
+
         if self.training:
             mean_class_probs = torch.mean(x.detach(), dim=(0,1))
             self.class_probs = (self.class_probs * self.class_probs_decay +
@@ -248,11 +936,82 @@ class DiscreteBottleneck(nn.Module):
             prob_floor = self.min_prob_ratio / self.classes_per_group
             self.class_offsets += (self.class_probs > prob_floor) * self.prob_boost
 
-        x = self.linear2(x)
-        x = self.norm_out(x)
-        return x
+        embedding = self.norm_out(self.linear2(x))
+        return (embedding, sampled, softmax)
 
+    def compute_prob(self, x: Tensor, sampled: Tensor, softmax: Optional[Tensor],
+                     padding_mask: Optional[Tensor],
+                     reverse_gradient: bool = False) -> Tensor:
+        """
+        Computes the total log-prob of the sampled class probabilities, given
+        some kind of predictor x (which we assume should not have access
+        to the output on the current frame, but might have access to
+        those of previous frames).
+
+          x: The predictor tensor, of shape (S, N, E) where S is the
+             sequence length, N is the batch size and E is the embedding dim
+             (the `dim` arg to __init__())
+          sampled: A tensor of shape (S, N, C) where C is the `tot_classes`
+             arg to the constructor, containing the sampled probabilities.
+          softmax: A tensor of shape (S, N, C); this is the "smooth" version
+             of `sampled`, which we use as the target in order to get
+             lower-variance derivatives with the same expectation.
+             If not provided, it will default to `sampled`.
+          padding_mask: Optionally, a boolean tensor of shape (N, S), i.e.
+             (batch_size, sequence_length), with True in masked positions
+             that are to be ignored in the sum of probabilities.
+          reverse_gradient: If true, negate the gradient that is passed back
+             to 'x' and to the modules self.pred_linear and self.pred_cross.
+             This will be useful in computing a loss function that has
+             a likelihood term with negative sign (i.e. the self-prediction).
+             We'll later need to negate the gradient one more time
+             where we give the input to whatever module generated 'x'.
+
+        Returns a scalar Tensor representing the total log-probability.
+        """
+        if softmax is None:
+            softmax = sampled
+        if reverse_gradient:
+            sampled = ReverseGrad.apply(sampled)
+            softmax = ReverseGrad.apply(softmax)
+
+        logprobs = self.pred_linear(x)
+
+        if self.num_groups > 1:
+            pred_cross = self.pred_cross * self.pred_cross_mask
+            t = self.tot_classes
+            c = self.classes_per_group
+
+            # all but the last group.  Note: we could possibly use softmax
+            # here, but I was concerned about information leakage.
+            cross_in = sampled[:, :, 0:t-c]
+            # The row index of pred_cross corresponds to the output, the column
+            # to the input -> we must transpose before multiplying.
+            cross_out = torch.matmul(cross_in, pred_cross.transpose(0, 1))
+            # Add the output of this matrix multiplication to all but the first
+            # group in `logprobs`.  Each group is predicted based on previous
+            # groups.
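+            # Shapes (with C = tot_classes): cross_in is (S, N, C - c),
+            # pred_cross is (C - c, C - c), so cross_out is (S, N, C - c),
+            # which lines up with logprobs[:, :, c:] below.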
+            logprobs[:, :, c:] += cross_out
+
+        (S, N, C) = logprobs.shape
+        logprobs = logprobs.reshape(S, N, self.num_groups, self.classes_per_group)
+        # Normalize the log-probs (so they sum to one within each group)
+        logprobs = torch.nn.functional.log_softmax(logprobs, dim=-1)
+        logprobs = logprobs.reshape(S, N, C)
+
+        if padding_mask is not None:
+            assert padding_mask.dtype == torch.bool and padding_mask.shape == (N, S)
+            padding_mask = torch.logical_not(padding_mask).transpose(0, 1).unsqueeze(-1)
+            # padding_mask now has shape (S, N, 1)
+            tot_prob = (logprobs * softmax * padding_mask).sum()
+        else:
+            tot_prob = (logprobs * softmax).sum()
+
+        if reverse_gradient:
+            tot_prob = ReverseGrad.apply(tot_prob)
+        return tot_prob
+
 
 class ConformerEncoderLayer(nn.Module):
     """
@@ -265,7 +1024,6 @@ class ConformerEncoderLayer(nn.Module):
         dim_feedforward: the dimension of the feedforward network model (default=2048).
         dropout: the dropout value (default=0.1).
         cnn_module_kernel (int): Kernel size of convolution module.
-        normalize_before: whether to use layer_norm before the first block.
 
     Examples::
         >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
@@ -281,12 +1039,10 @@ class ConformerEncoderLayer(nn.Module):
         dim_feedforward: int = 2048,
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
-        normalize_before: bool = True,
-        is_espnet_structure: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
         self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0, is_espnet_structure=is_espnet_structure
+            d_model, nhead, dropout=0.0,
         )
 
         self.feed_forward = nn.Sequential(
@@ -320,8 +1076,6 @@ class ConformerEncoderLayer(nn.Module):
 
         self.dropout = nn.Dropout(dropout)
 
-        self.normalize_before = normalize_before
-
     def forward(
         self,
         src: Tensor,
@@ -348,18 +1102,14 @@ class ConformerEncoderLayer(nn.Module):
 
         # macaron style feed forward module
         residual = src
-        if self.normalize_before:
-            src = self.norm_ff_macaron(src)
+        src = self.norm_ff_macaron(src)
         src = residual + self.ff_scale * self.dropout(
             self.feed_forward_macaron(src)
         )
-        if not self.normalize_before:
-            src = self.norm_ff_macaron(src)
 
         # multi-headed self-attention module
         residual = src
-        if self.normalize_before:
-            src = self.norm_mha(src)
+        src = self.norm_mha(src)
         src_att = self.self_attn(
             src,
             src,
@@ -369,32 +1119,23 @@ class ConformerEncoderLayer(nn.Module):
             key_padding_mask=src_key_padding_mask,
        )[0]
         src = residual + self.dropout(src_att)
-        if not self.normalize_before:
-            src = self.norm_mha(src)
 
         # convolution module
         residual = src
-        if self.normalize_before:
-            src = self.norm_conv(src)
+        src = self.norm_conv(src)
         src = residual + self.dropout(self.conv_module(src))
-        if not self.normalize_before:
-            src = self.norm_conv(src)
 
         # feed forward module
         residual = src
-        if self.normalize_before:
-            src = self.norm_ff(src)
+        src = self.norm_ff(src)
         src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
-        if not self.normalize_before:
-            src = self.norm_ff(src)
 
-        if self.normalize_before:
-            src = self.norm_final(src)
+        src = self.norm_final(src)
 
         return src
 
 
-class ConformerEncoder(nn.TransformerEncoder):
+class ConformerEncoder(nn.Module):
     r"""ConformerEncoder is a stack of N encoder layers
 
     Args:
@@ -410,56 +1151,37 @@ class ConformerEncoder(nn.TransformerEncoder):
         >>> out = conformer_encoder(src, pos_emb)
     """
 
-    def __init__(
-        self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None,
-        discrete_bottleneck: Optional[nn.Module] = None,
-        discrete_bottleneck_pos: Optional[int] = None
-    ) -> None:
-        super(ConformerEncoder, self).__init__(
-            encoder_layer=encoder_layer, num_layers=num_layers, norm=norm
-        )
-        self.discrete_bottleneck = discrete_bottleneck
-        self.discrete_bottleneck_pos = discrete_bottleneck_pos
+    def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
+        super(ConformerEncoder, self).__init__()
+        self.layers = nn.ModuleList(
+            [copy.deepcopy(encoder_layer) for _ in range(num_layers)])
 
     def forward(
         self,
-        src: Tensor,
+        x: Tensor,
         pos_emb: Tensor,
-        mask: Optional[Tensor] = None,
-        src_key_padding_mask: Optional[Tensor] = None,
+        attn_mask: Optional[Tensor] = None,
+        key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.
-
-        Args:
-            src: the sequence to the encoder (required).
-            pos_emb: Positional embedding tensor (required).
-            mask: the mask for the src sequence (optional).
-            src_key_padding_mask: the mask for the src keys per batch (optional).
-
-        Shape:
-            src: (S, N, E).
-            pos_emb: (N, 2*S-1, E)
-            mask: (S, S).
-            src_key_padding_mask: (N, S).
-            S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
-
+        Args:
+            x: input of shape (T, N, C), i.e. (seq_len, batch, channels)
+            pos_emb: positional embedding tensor of shape (1, 2*T-1, C)
+            attn_mask (optional, likely not used): mask for the self-attention of
+                x to itself, of shape (T, T)
+            key_padding_mask (optional): mask of shape (N, T); its dtype must be bool.
+        Returns:
+            A tensor with the same shape as x, i.e. (T, N, C).
         """
-        output = src
-
-        for i, mod in enumerate(self.layers):
-            if i == self.discrete_bottleneck_pos:
-                output = self.discrete_bottleneck(output)
-            output = mod(
-                output,
+        for mod in self.layers:
+            x = mod(
+                x,
                 pos_emb,
-                src_mask=mask,
-                src_key_padding_mask=src_key_padding_mask,
+                src_mask=attn_mask,
+                src_key_padding_mask=key_padding_mask,
             )
 
-        if self.norm is not None:
-            output = self.norm(output)
+        return x
 
-        return output
 
 
 class RelPositionalEncoding(torch.nn.Module):
@@ -565,7 +1287,6 @@ class RelPositionMultiheadAttention(nn.Module):
         embed_dim: int,
         num_heads: int,
         dropout: float = 0.0,
-        is_espnet_structure: bool = False,
     ) -> None:
         super(RelPositionMultiheadAttention, self).__init__()
         self.embed_dim = embed_dim
@@ -588,8 +1309,6 @@ class RelPositionMultiheadAttention(nn.Module):
 
         self._reset_parameters()
 
-        self.is_espnet_structure = is_espnet_structure
-
     def _reset_parameters(self) -> None:
         nn.init.xavier_uniform_(self.in_proj.weight)
         nn.init.constant_(self.in_proj.bias, 0.0)
@@ -819,9 +1538,6 @@ class RelPositionMultiheadAttention(nn.Module):
                 _b = _b[_start:]
             v = nn.functional.linear(value, _w, _b)
 
-        if not self.is_espnet_structure:
-            q = q * scaling
-
         if attn_mask is not None:
             assert (
                 attn_mask.dtype == torch.float32
@@ -914,14 +1630,9 @@ class RelPositionMultiheadAttention(nn.Module):
         )  # (batch, head, time1, 2*time1-1)
         matrix_bd = self.rel_shift(matrix_bd)
 
-        if not self.is_espnet_structure:
-            attn_output_weights = (
-                matrix_ac + matrix_bd
-            )  # (batch, head, time1, time2)
-        else:
-            attn_output_weights = (
-                matrix_ac + matrix_bd
-            ) * scaling  # (batch, head, time1, time2)
+        attn_output_weights = (
+            matrix_ac + matrix_bd
+        ) * scaling  # (batch, head, time1, time2)
 
         attn_output_weights = attn_output_weights.view(
             bsz * num_heads, tgt_len, -1
@@ -976,6 +1687,83 @@ class RelPositionMultiheadAttention(nn.Module):
         else:
             return attn_output, None
 
+
+class CausalConvolutionModule(nn.Module):
+    """A modified version of the ConvolutionModule in the Conformer model.
+    This is a causal version of it (it sees only left-context).
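+    Causality is obtained by left-padding the input to the depthwise
+    convolution with (kernel_size - 1) zeros instead of using symmetric
+    'SAME' padding, so each output frame only sees the current and
+    previous frames.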
+
+    Args:
+        channels (int): The number of channels of conv layers.
+        kernel_size (int): Kernel size of conv layers.
+        bias (bool): Whether to use bias in conv layers (default=True).
+    """
+
+    def __init__(
+        self, channels: int, kernel_size: int, bias: bool = True
+    ) -> None:
+        """Construct a CausalConvolutionModule object."""
+        super(CausalConvolutionModule, self).__init__()
+        # kernel_size should be an odd number, as in the (non-causal)
+        # ConvolutionModule.
+        assert (kernel_size - 1) % 2 == 0
+        self.kernel_size = kernel_size
+
+        self.pointwise_conv1 = nn.Conv1d(
+            channels,
+            2 * channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+        self.depthwise_conv = nn.Conv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=0,  # We'll pad manually (on the left only)
+            groups=channels,
+            bias=bias,
+        )
+        self.norm = nn.BatchNorm1d(channels)
+        self.pointwise_conv2 = nn.Conv1d(
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+        self.activation = Swish()
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Compute the convolution module.
+
+        Args:
+            x: Input tensor (#time, batch, channels).
+
+        Returns:
+            Tensor: Output tensor (#time, batch, channels).
+
+        """
+        # exchange the temporal dimension and the feature dimension
+        x = x.permute(1, 2, 0)  # (#batch, channels, time).
+
+        # GLU mechanism
+        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
+        x = nn.functional.glu(x, dim=1)  # (batch, channels, time)
+
+        # 1D depthwise conv, with manual left-padding so that it is causal.
+        (B, C, T) = x.shape
+        padding = self.kernel_size - 1
+        x = torch.cat((torch.zeros(B, C, padding, device=x.device, dtype=x.dtype), x),
+                      dim=2)
+        x = self.depthwise_conv(x)  # <-- This has no padding of its own.
+
+        x = self.activation(self.norm(x))
+
+        x = self.pointwise_conv2(x)  # (batch, channels, time)
+
+        return x.permute(2, 0, 1)
+
 
 class ConvolutionModule(nn.Module):
     """ConvolutionModule in Conformer model.
diff --git a/egs/librispeech/ASR/conformer_ctc_bn_2d/transformer.py b/egs/librispeech/ASR/conformer_ctc_bn_2d/transformer.py
index 191d2d612..5d5213f56 100644
--- a/egs/librispeech/ASR/conformer_ctc_bn_2d/transformer.py
+++ b/egs/librispeech/ASR/conformer_ctc_bn_2d/transformer.py
@@ -939,7 +939,7 @@ def decoder_padding_mask(
     return ys_mask
 
 
-def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
+def generate_square_subsequent_mask(sz: int, device: torch.device = torch.device('cpu')) -> torch.Tensor:
     """Generate a square mask for the sequence. The masked positions are
     filled with float('-inf'). Unmasked positions are filled with float(0.0).
     The mask can be used for masked self-attention.
@@ -956,7 +956,7 @@ def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
     Returns:
       A square mask of dimension (sz, sz)
     """
-    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+    mask = (torch.triu(torch.ones(sz, sz, device=device)) == 1).transpose(0, 1)
     mask = (
         mask.float()
         .masked_fill(mask == 0, float("-inf"))