From 6f8b7b9c3b5c588cc2ef4dcf2e64f2b78813608d Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 21 Sep 2021 21:52:17 +0800
Subject: [PATCH] First version that seems to be converging OK...

---
 .../ASR/conformer_ctc_bn_2d/conformer.py | 44 +++++++++++++++++--
 .../ASR/conformer_ctc_bn_2d/train.py     | 38 +++++++++++++---
 2 files changed, 74 insertions(+), 8 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py b/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
index 8f7c9a7d5..fffcf7df7 100644
--- a/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
+++ b/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
@@ -222,6 +222,11 @@ class BidirectionalConformer(nn.Module):
             nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
         )
 
+        self.bottleneck_ctc_encoder = ConformerEncoder(encoder_layer, num_ctc_encoder_layers)
+        self.bottleneck_ctc_output_layer = nn.Sequential(
+            nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
+        )
+
         # absolute position encoding, used by various layer types
         self.abs_pos = PositionalEncoding(d_model, dropout)
 
@@ -354,7 +359,8 @@ class BidirectionalConformer(nn.Module):
         Given the "memory" from forward(), run the sample_and_predict module.
         See documentation for forward() of class SampleAndPredict for more info.
 
-        Returns (sampled, softmax, positive_embed_shifted, negative_embed_shifted),
+        Returns (sampled, softmax, positive_embed, positive_embed_shifted,
+        negative_embed_shifted),
         where positive_embed_shifted, for instance, is positive_embed shifted
         by one so that positive_embed_shifted[t] == positive_embed[t-1], as in:
               (T, N, E) = positive_embed.shape
@@ -368,7 +374,7 @@
         positive_embed_shifted = torch.cat((zeros, positive_embed[:-1,:,:]), dim=0)
         negative_embed_shifted = torch.cat((zeros, negative_embed[:-1,:,:]), dim=0)
 
-        return (sampled, softmax, positive_embed_shifted, negative_embed_shifted)
+        return (sampled, softmax, positive_embed, positive_embed_shifted, negative_embed_shifted)
 
     def decoder_forward(
         self,
@@ -451,7 +457,7 @@ class BidirectionalConformer(nn.Module):
     ) -> torch.Tensor:
         """
         Passes the output of forward() through the CTC encoder and the CTC
-        output to give the output that can be given to the CTC loss function
+        output layer to give the output that can be given to the CTC loss function
 
         Args:
           memory:
@@ -474,6 +480,38 @@ class BidirectionalConformer(nn.Module):
         x = nn.functional.log_softmax(x, dim=-1)  # (N, T, C)
         return x
 
+    def bottleneck_ctc_encoder_forward(
+        self,
+        positive_embed: torch.Tensor,
+        pos_emb: torch.Tensor,
+        memory_key_padding_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Passes the output of sample_forward() through the "from-bottleneck"
+        CTC encoder and the CTC output layer to give the output that can be
+        given to the CTC loss function
+
+        Args:
+          positive_embed:
+            One of the outputs of sample_forward(), with shape (T, N, E)
+          pos_emb:
+            Relative positional embedding tensor: (N, 2*T-1, E)
+          memory_key_padding_mask:
+            The padding mask from forward(), a tensor of bool of shape (N, T)
+
+        Returns:
+          A Tensor with shape [N, T, C] where C is the number of classes
+          (e.g. number of phones or word pieces).  Contains normalized
+          log-probabilities.
+ """ + x = self.bottleneck_ctc_encoder(positive_embed, + pos_emb, + key_padding_mask=memory_key_padding_mask) + x = self.bottleneck_ctc_output_layer(x) + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + x = nn.functional.log_softmax(x, dim=-1) # (N, T, C) + return x + def self_prediction_forward( self, diff --git a/egs/librispeech/ASR/conformer_ctc_bn_2d/train.py b/egs/librispeech/ASR/conformer_ctc_bn_2d/train.py index 16e8baea1..3eec8d9b4 100755 --- a/egs/librispeech/ASR/conformer_ctc_bn_2d/train.py +++ b/egs/librispeech/ASR/conformer_ctc_bn_2d/train.py @@ -171,8 +171,9 @@ def get_params() -> AttributeDict: "reduction": "sum", "use_double_scores": True, "accum_grad": 1, - "att_scale": 0.6, - "reverse_att_scale": 0.1, # ctc_scale == 1.0 - att_scale - reverse_att_scale + "att_scale": 0.3, + "reverse_att_scale": 0.2, + "bottleneck_ctc_scale": 0.2, # ctc_scale == 1.0 - att_scale - reverse_att_scale - bottleneck_ctc_scale "attention_dim": 512, "nhead": 8, "num_trunk_encoder_layers": 12, @@ -391,6 +392,7 @@ def compute_loss( memory, position_embedding, memory_mask = model(feature, supervisions) # memory's shape is (N, T, C) + ctc_output = mmodel.ctc_encoder_forward(memory, position_embedding, memory_mask) @@ -435,10 +437,14 @@ def compute_loss( if params.reverse_att_scale != 0.0: with torch.set_grad_enabled(is_training): - (sampled, softmax, + (sampled, softmax, positive_embed, positive_embed_shifted, negative_embed_shifted) = mmodel.sample_forward(memory) + #if True: # TEMP + # positive_embed_shifted = torch.randn_like(positive_embed_shifted) + # negative_embed_shifted = positive_embed_shifted + reverse_decoder_logprob = mmodel.reverse_decoder_forward( positive_embed_shifted, memory_mask, @@ -465,19 +471,40 @@ def compute_loss( print(f"Self-prediction logprob = {self_prediction_logprob/num_frames}, " f"reverse-decoder logprob = {reverse_decoder_logprob/num_frames}, " f"reverse_att_loss = {reverse_att_loss/num_frames}") + + bottleneck_ctc_output = mmodel.bottleneck_ctc_encoder_forward(positive_embed, + position_embedding, + memory_mask) + + dense_fsa_vec = k2.DenseFsaVec( + bottleneck_ctc_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + bottleneck_ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) else: reverse_att_loss = torch.tensor([0.0]).to(device) + bottleneck_ctc_loss = torch.tensor([0.0]).to(device) - ctc_scale = 1.0 - params.att_scale - params.reverse_att_scale + ctc_scale = 1.0 - params.att_scale - params.reverse_att_scale - params.bottleneck_ctc_scale loss = (ctc_scale * ctc_loss + + params.bottleneck_ctc_scale * bottleneck_ctc_loss + params.att_scale * att_loss + - params.reverse_att_scale * reverse_att_loss) + (params.reverse_att_scale if params.cur_epoch > 0 else 0.01 * params.reverse_att_scale) * reverse_att_loss) assert loss.requires_grad == is_training info = LossRecord() # TODO: there are many GPU->CPU transfers here, maybe combine them into one. 
     info['frames'] = supervision_segments[:, 2].sum().item()
     info['ctc_loss'] = ctc_loss.detach().cpu().item()
+    info['bottleneck_ctc_loss'] = bottleneck_ctc_loss.detach().cpu().item()
     if params.att_scale != 0.0:
         info['att_loss'] = att_loss.detach().cpu().item()
 
     if params.reverse_att_scale != 0.0:
@@ -709,6 +736,7 @@ def run(rank, world_size, args):
     for epoch in range(params.start_epoch, params.num_epochs):
         optimizer.set_epoch(epoch)  # specific to Gloam
         train_dl.sampler.set_epoch(epoch)
+        params.cur_epoch = epoch
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
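
Two details of this patch are easy to misread: the one-frame shift that sample_forward() applies to the embeddings, and how the four loss scales combine. The sketch below is a minimal standalone illustration, not code from the patch; the tensor sizes are made up, and the scale values are the ones set in get_params() above.

    import torch

    # The shift documented in sample_forward(): shifted[t] == embed[t-1],
    # with zeros at t == 0.  Shapes are (T, N, E).
    T, N, E = 5, 2, 4  # made-up sizes, for illustration only
    positive_embed = torch.randn(T, N, E)
    zeros = torch.zeros(1, N, E)
    positive_embed_shifted = torch.cat((zeros, positive_embed[:-1, :, :]), dim=0)
    assert torch.equal(positive_embed_shifted[1:], positive_embed[:-1])

    # The loss interpolation after this patch: ctc_scale is defined as the
    # remainder, so the four scales always sum to 1.0.
    att_scale, reverse_att_scale, bottleneck_ctc_scale = 0.3, 0.2, 0.2
    ctc_scale = 1.0 - att_scale - reverse_att_scale - bottleneck_ctc_scale  # 0.3, up to float rounding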