From 6fa0f16e0c36bbb89dbb1a87ee665e4d386d415e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 23 Sep 2021 17:31:29 +0800 Subject: [PATCH] Remove reconstruction loss, have randomly averaged CTC loss --- .../ASR/conformer_ctc_bn_2d/conformer.py | 23 +++++++++++++++++++ .../ASR/conformer_ctc_bn_2d/train.py | 14 ++++------- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py b/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py index 1508d7c4b..99d928d8c 100644 --- a/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py @@ -222,6 +222,11 @@ class BidirectionalConformer(nn.Module): nn.Dropout(p=dropout), nn.Linear(d_model, num_classes) ) + self.bottleneck_ctc_encoder = ConformerEncoder(encoder_layer, num_ctc_encoder_layers) + self.bottleneck_ctc_output_layer = nn.Sequential( + nn.Dropout(p=dropout), nn.Linear(d_model, num_classes) + ) + # absolute position encoding, used by various layer types self.abs_pos = PositionalEncoding(d_model, dropout) @@ -449,6 +454,7 @@ class BidirectionalConformer(nn.Module): memory: torch.Tensor, pos_emb: torch.Tensor, memory_key_padding_mask: torch.Tensor, + positive_embed: Optional[torch.Tensor] = None ) -> torch.Tensor: """ Passes the output of forward() through the CTC encoder and the CTC @@ -461,6 +467,8 @@ class BidirectionalConformer(nn.Module): Relative positional embedding tensor: (N, 2*T-1, E) memory_key_padding_mask: The padding mask from forward(), a tensor of bool of shape (N, T) + positive_embed: + Needed only during training, so we can train the bottleneck layer.. Returns: A Tensor with shape [N, T, C] where C is the number of classes @@ -473,6 +481,21 @@ class BidirectionalConformer(nn.Module): x = self.ctc_output_layer(x) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) x = nn.functional.log_softmax(x, dim=-1) # (N, T, C) + + if self.training: + # Randomly interpolate half-and-half with the bottleneck CTC + # encoder, at the frame level + y = self.bottleneck_ctc_encoder(positive_embed, + pos_emb, + key_padding_mask=memory_key_padding_mask) + y = self.bottleneck_ctc_output_layer(y) + y = y.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + y = nn.functional.log_softmax(y, dim=-1) # (N, T, C) + (N, T, C) = y.shape + r = torch.rand(N, T, 1, device=y.device) + x = (y * r) + x - (x * r) + + return x diff --git a/egs/librispeech/ASR/conformer_ctc_bn_2d/train.py b/egs/librispeech/ASR/conformer_ctc_bn_2d/train.py index 92d0ad505..88fe7ab52 100755 --- a/egs/librispeech/ASR/conformer_ctc_bn_2d/train.py +++ b/egs/librispeech/ASR/conformer_ctc_bn_2d/train.py @@ -156,7 +156,7 @@ def get_params() -> AttributeDict: """ params = AttributeDict( { - "exp_dir": Path("conformer_ctc_bn_2d/exp_bidirectional_delay"), + "exp_dir": Path("conformer_ctc_bn_2d/exp_bidirectional_delay_norecon"), "lang_dir": Path("data/lang_bpe"), "feature_dim": 80, "subsampling_factor": 4, # can't be changed @@ -175,8 +175,7 @@ def get_params() -> AttributeDict: "att_scale": 0.5, "reverse_att_scale": 0.2, "ctc_scale": 0.3, - "reconstruction_scale": 0.5, # Scale on log of reconstruction error after discrete bottleneck. - "delay_scale": 2.0, # Scale on difference between current and + "delay_scale": 0.1, # Scale on difference between current and # delayed version of positive_embed. "delay_minibatches": 200, "attention_dim": 512, @@ -476,12 +475,11 @@ def compute_loss( delay_loss = compute_distance(old_positive_embed, positive_embed) num_subsampled_frames = memory.shape[0] * memory.shape[1] - reconstruction_loss = (((positive_embed - memory.detach()) ** 2).sum() / num_subsampled_frames).sqrt() * num_subsampled_frames - ctc_output = mmodel.ctc_encoder_forward(memory, - position_embedding, - memory_mask) + position_embedding, + memory_mask, + positive_embed) # NOTE: We need `encode_supervisions` to sort sequences with @@ -556,7 +554,6 @@ def compute_loss( loss = (params.ctc_scale * ctc_loss + - (params.reconstruction_scale if params.cur_epoch > 0 else 0.1 * params.reconstruction_scale) * reconstruction_loss + params.att_scale * att_loss + (params.reverse_att_scale if params.cur_epoch > 0 else 0.001 * params.reverse_att_scale) * reverse_att_loss) if params.cur_epoch > 0 and params.delay_scale > 0.0: @@ -569,7 +566,6 @@ def compute_loss( # TODO: there are many GPU->CPU transfers here, maybe combine them into one. info['frames'] = supervision_segments[:, 2].sum().item() info['ctc_loss'] = ctc_loss.detach().cpu().item() - info['reconstruction_loss'] = reconstruction_loss.detach().cpu().item() if params.cur_epoch > 0 and params.delay_scale > 0.0: info['delay_loss'] = delay_loss.detach().cpu().item() if params.att_scale != 0.0: