From 2213457bd31be68e3b67bde256395f4d11316096 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 23 Sep 2021 11:25:42 +0800
Subject: [PATCH] Initially working version with delay_loss...

---
 .../ASR/conformer_ctc_bn_2d/conformer.py |  37 -----
 .../ASR/conformer_ctc_bn_2d/train.py     | 133 +++++++++++++-----
 2 files changed, 100 insertions(+), 70 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py b/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
index fffcf7df7..1508d7c4b 100644
--- a/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
+++ b/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
@@ -222,11 +222,6 @@ class BidirectionalConformer(nn.Module):
             nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
         )
 
-        self.bottleneck_ctc_encoder = ConformerEncoder(encoder_layer, num_ctc_encoder_layers)
-        self.bottleneck_ctc_output_layer = nn.Sequential(
-            nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
-        )
-
         # absolute position encoding, used by various layer types
         self.abs_pos = PositionalEncoding(d_model, dropout)
 
@@ -480,38 +475,6 @@ class BidirectionalConformer(nn.Module):
         x = nn.functional.log_softmax(x, dim=-1)  # (N, T, C)
         return x
 
-    def bottleneck_ctc_encoder_forward(
-        self,
-        positive_embed: torch.Tensor,
-        pos_emb: torch.Tensor,
-        memory_key_padding_mask: torch.Tensor,
-    ) -> torch.Tensor:
-        """
-        Passes the output of sample_forward() through the CTC "from-bottleneck"
-        CTC encoder and the CTC
-        output layer to give the output that can be given to the CTC loss function
-
-        Args:
-          positive_embed:
-            One of the outputs of sample_forward(), with shape (T, N, E)
-          pos_emb:
-            Relative positional embedding tensor: (N, 2*T-1, E)
-          memory_key_padding_mask:
-            The padding mask from forward(), a tensor of bool of shape (N, T)
-
-        Returns:
-          A Tensor with shape [N, T, C] where C is the number of classes
-          (e.g. number of phones or word pieces).  Contains normalized
-          log-probabilities.
-        """
-        x = self.bottleneck_ctc_encoder(positive_embed,
-                                        pos_emb,
-                                        key_padding_mask=memory_key_padding_mask)
-        x = self.bottleneck_ctc_output_layer(x)
-        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
-        x = nn.functional.log_softmax(x, dim=-1)  # (N, T, C)
-        return x
-
     def self_prediction_forward(
         self,
diff --git a/egs/librispeech/ASR/conformer_ctc_bn_2d/train.py b/egs/librispeech/ASR/conformer_ctc_bn_2d/train.py
index 3eec8d9b4..92d0ad505 100755
--- a/egs/librispeech/ASR/conformer_ctc_bn_2d/train.py
+++ b/egs/librispeech/ASR/conformer_ctc_bn_2d/train.py
@@ -20,6 +20,7 @@
 
 import argparse
 import collections
+import copy
 import logging
 from pathlib import Path
 import random  # temp..
@@ -155,7 +156,7 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_ctc_bn_2d/exp_bidirectional_2"),
+            "exp_dir": Path("conformer_ctc_bn_2d/exp_bidirectional_delay"),
             "lang_dir": Path("data/lang_bpe"),
             "feature_dim": 80,
             "subsampling_factor": 4,  # can't be changed
@@ -171,9 +172,13 @@ def get_params() -> AttributeDict:
             "reduction": "sum",
             "use_double_scores": True,
             "accum_grad": 1,
-            "att_scale": 0.3,
+            "att_scale": 0.5,
             "reverse_att_scale": 0.2,
-            "bottleneck_ctc_scale": 0.2,  # ctc_scale == 1.0 - att_scale - reverse_att_scale - bottleneck_ctc_scale
+            "ctc_scale": 0.3,
+            "reconstruction_scale": 0.5,  # Scale on reconstruction error after discrete bottleneck.
+            "delay_scale": 2.0,  # Scale on difference between current and
+                                 # delayed version of positive_embed.
+ "delay_minibatches": 200, "attention_dim": 512, "nhead": 8, "num_trunk_encoder_layers": 12, @@ -349,6 +354,69 @@ class LossRecord(collections.defaultdict): tb_writer.add_scalar(prefix + k, v, batch_idx) +def get_delayed_model(model: nn.Module, + params: AttributeDict) -> nn.Module: + if hasattr(model, "module"): + model = model.module + delay_minibatches = params.delay_minibatches + + cur_batch_idx = params.batch_idx_train + + try: + # hasattr doesn't seem to work for this... use try-except to test if it + # has the attribute. + _ = params.cur_delayed_model + except KeyError: + params.cur_delayed_model = copy.deepcopy(model) + params.prev_delayed_model = params.cur_delayed_model + params.cur_delayed_batch_idx = cur_batch_idx + + if params.cur_delayed_batch_idx <= cur_batch_idx - delay_minibatches: + params.cur_delayed_batch_idx = cur_batch_idx + params.prev_delayed_model = params.cur_delayed_model + params.cur_delayed_model = copy.deepcopy(model) + + return params.prev_delayed_model + + +def compute_distance(feats1, feats2): + """ + Assumes that feats1 and feats2 are some kind of features with (N, T, C) or (T, N, C) layout. + Computes a distance between them, that will have the property that its derivative w.r.t. + feats2 will be orthogonal to feats2. This will avoid any direct pressure for feats2 to + grow or shrink. (We assume that feats1 is without grad). Let K = T * N (the total + number of frames. Then the returned value is half the total, over the frames, of the + log of the (average squared distance, per frame), between feats1 and (alpha * feats2), + for optimally chosen alpha, i.e.: + + ans = K * (((feats1 - alpha * feats2) ** 2).sum() / K).sqrt() + + [the minimum of that result, for any alpha]. d/d alpha of: + + ((feats1 - alpha * feats2) * (feats1 - alpha * feats2)).sum() + + Expanded in terms of alpha, this is: + (feats1**2).sum() + alpha**2 *(feats2**2).sum() - 2 * alpha * (feats1*feats2).sum() + and d/dalpha of this is: + 2 * alpha *(feats2**2).sum() - 2 * (feats1*feats2).sum() + so alpha = (feats1*feats2).sum() / (feats2**2).sum() + + """ + feats1_prod = (feats1 ** 2).sum() + feats2_prod = (feats2 ** 2).sum() + cross_prod = (feats2 * feats1).sum() + alpha = cross_prod.detach() / feats2_prod.detach() + if random.random() < 0.01: + logging.info(f"compute_distance: alpha = {alpha.to('cpu').item()}") + K = feats1.shape[0] * feats1.shape[1] + avg_dist = ((feats1_prod + (alpha**2 * feats2_prod) - 2 * alpha * cross_prod) / K) + if avg_dist <= 0.0: + avg_dist = torch.tensor([0.0], device=feats1.device) + else: + avg_dist = avg_dist.sqrt() + return K * avg_dist + + def compute_loss( params: AttributeDict, @@ -388,10 +456,28 @@ def compute_loss( mmodel = model.module if hasattr(model, "module") else model + if params.cur_epoch > 0 and params.delay_scale > 0.0: + with torch.no_grad(): + delayed_model = get_delayed_model(model, params) + with torch.random.fork_rng(devices=[device], enabled=True): + (old_memory, _, _) = delayed_model(feature, supervisions) + (_, _, old_positive_embed, _, _) = delayed_model.sample_forward(old_memory) + + with torch.set_grad_enabled(is_training): memory, position_embedding, memory_mask = model(feature, supervisions) # memory's shape is (N, T, C) + (sampled, softmax, positive_embed, + positive_embed_shifted, + negative_embed_shifted) = mmodel.sample_forward(memory) + + if params.cur_epoch > 0 and params.delay_scale > 0.0: + delay_loss = compute_distance(old_positive_embed, positive_embed) + + num_subsampled_frames = memory.shape[0] * memory.shape[1] + 
+        reconstruction_loss = (((positive_embed - memory.detach()) ** 2).sum() / num_subsampled_frames).sqrt() * num_subsampled_frames
+
         ctc_output = mmodel.ctc_encoder_forward(memory,
                                                 position_embedding,
@@ -437,13 +523,6 @@ def compute_loss(
 
     if params.reverse_att_scale != 0.0:
         with torch.set_grad_enabled(is_training):
-            (sampled, softmax, positive_embed,
-             positive_embed_shifted,
-             negative_embed_shifted) = mmodel.sample_forward(memory)
-
-            #if True: # TEMP
-            #    positive_embed_shifted = torch.randn_like(positive_embed_shifted)
-            #    negative_embed_shifted = positive_embed_shifted
 
             reverse_decoder_logprob = mmodel.reverse_decoder_forward(
                 positive_embed_shifted,
@@ -472,39 +551,27 @@ def compute_loss(
                          f"reverse-decoder logprob = {reverse_decoder_logprob/num_frames}, "
                          f"reverse_att_loss = {reverse_att_loss/num_frames}")
 
-        bottleneck_ctc_output = mmodel.bottleneck_ctc_encoder_forward(positive_embed,
-                                                                      position_embedding,
-                                                                      memory_mask)
-
-        dense_fsa_vec = k2.DenseFsaVec(
-            bottleneck_ctc_output,
-            supervision_segments,
-            allow_truncate=params.subsampling_factor - 1,
-        )
-
-        bottleneck_ctc_loss = k2.ctc_loss(
-            decoding_graph=decoding_graph,
-            dense_fsa_vec=dense_fsa_vec,
-            output_beam=params.beam_size,
-            reduction=params.reduction,
-            use_double_scores=params.use_double_scores,
-        )
     else:
         reverse_att_loss = torch.tensor([0.0]).to(device)
-        bottleneck_ctc_loss = torch.tensor([0.0]).to(device)
 
-    ctc_scale = 1.0 - params.att_scale - params.reverse_att_scale - params.bottleneck_ctc_scale
-    loss = (ctc_scale * ctc_loss +
-            params.bottleneck_ctc_scale * bottleneck_ctc_loss +
+
+    loss = (params.ctc_scale * ctc_loss +
+            (params.reconstruction_scale if params.cur_epoch > 0 else 0.1 * params.reconstruction_scale) * reconstruction_loss +
             params.att_scale * att_loss +
-            (params.reverse_att_scale if params.cur_epoch > 0 else 0.01 * params.reverse_att_scale) * reverse_att_loss)
+            (params.reverse_att_scale if params.cur_epoch > 0 else 0.001 * params.reverse_att_scale) * reverse_att_loss)
+    if params.cur_epoch > 0 and params.delay_scale > 0.0:
+        loss = loss + params.delay_scale * delay_loss
+
+
     assert loss.requires_grad == is_training
 
     info = LossRecord()
     # TODO: there are many GPU->CPU transfers here, maybe combine them into one.
     info['frames'] = supervision_segments[:, 2].sum().item()
     info['ctc_loss'] = ctc_loss.detach().cpu().item()
-    info['bottleneck_ctc_loss'] = bottleneck_ctc_loss.detach().cpu().item()
+    info['reconstruction_loss'] = reconstruction_loss.detach().cpu().item()
+    if params.cur_epoch > 0 and params.delay_scale > 0.0:
+        info['delay_loss'] = delay_loss.detach().cpu().item()
     if params.att_scale != 0.0:
         info['att_loss'] = att_loss.detach().cpu().item()
     if params.reverse_att_scale != 0.0:
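
The delay_loss introduced above compares the current positive_embed with one produced by a snapshot of the model taken roughly delay_minibatches earlier, under the same RNG state (torch.random.fork_rng).  The sketch below isolates the snapshot rotation that get_delayed_model() performs; it is a simplified standalone illustration, not code from the patch, and the DelayedModelTracker class and its get() method are hypothetical names (the patch stores the two snapshots and the batch index on params instead).

# Standalone sketch (not part of the patch): the snapshot rotation performed by
# get_delayed_model(), factored into a small class so the timing is explicit.
import copy

import torch.nn as nn


class DelayedModelTracker:
    def __init__(self, delay_minibatches: int):
        self.delay_minibatches = delay_minibatches
        self.prev = None           # snapshot handed back to the caller
        self.cur = None            # newer snapshot, promoted to `prev` later
        self.cur_batch_idx = None  # batch index at which `cur` was taken

    def get(self, model: nn.Module, batch_idx: int) -> nn.Module:
        if self.cur is None:
            # First call: both snapshots start as copies of the current model.
            self.cur = copy.deepcopy(model)
            self.prev = self.cur
            self.cur_batch_idx = batch_idx
        if self.cur_batch_idx <= batch_idx - self.delay_minibatches:
            # `cur` is now at least delay_minibatches old: promote it to `prev`
            # and take a fresh snapshot of the live model.
            self.cur_batch_idx = batch_idx
            self.prev = self.cur
            self.cur = copy.deepcopy(model)
        return self.prev


# Example: with delay_minibatches=200, after the first interval the returned
# snapshot lags the live model by between 200 and roughly 400 updates, which is
# the same behaviour as get_delayed_model().
tracker = DelayedModelTracker(delay_minibatches=200)
live_model = nn.Linear(4, 4)
for batch_idx in range(1000):
    delayed = tracker.get(live_model, batch_idx)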
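
compute_distance() rescales feats2 by the closed-form optimal alpha = (feats1*feats2).sum() / (feats2**2).sum() and then returns K times the per-frame RMS distance, which makes the gradient with respect to feats2 orthogonal to feats2.  The following self-contained check of that property re-implements the same computation outside the patch (the zero-distance branch is simplified); it is a sketch for illustration only.

# Standalone check (not part of the patch): verifies the property claimed in the
# compute_distance() docstring, namely that with the optimal alpha the gradient
# of the distance w.r.t. feats2 is orthogonal to feats2, so the loss puts no
# direct pressure on the scale of feats2.
import torch


def compute_distance(feats1: torch.Tensor, feats2: torch.Tensor) -> torch.Tensor:
    # Same computation as in the patch (the zero-distance branch is simplified).
    feats1_prod = (feats1 ** 2).sum()
    feats2_prod = (feats2 ** 2).sum()
    cross_prod = (feats2 * feats1).sum()
    alpha = cross_prod.detach() / feats2_prod.detach()
    K = feats1.shape[0] * feats1.shape[1]
    avg_dist = (feats1_prod + alpha ** 2 * feats2_prod - 2 * alpha * cross_prod) / K
    return K * avg_dist.clamp(min=1e-20).sqrt()


feats1 = torch.randn(4, 10, 8)                      # e.g. (T, N, C), no grad
feats2 = torch.randn(4, 10, 8, requires_grad=True)  # the embedding being trained

dist = compute_distance(feats1, feats2)
dist.backward()

# Should print a value very close to zero: the gradient is orthogonal to feats2.
print((feats2.grad * feats2).sum().item())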