diff --git a/egs/librispeech/ASR/conformer_ctc_bn_2d/train.py b/egs/librispeech/ASR/conformer_ctc_bn_2d/train.py
index 48a58d96c..670665110 100755
--- a/egs/librispeech/ASR/conformer_ctc_bn_2d/train.py
+++ b/egs/librispeech/ASR/conformer_ctc_bn_2d/train.py
@@ -17,18 +17,21 @@
 import argparse
+import collections
 import logging
 from pathlib import Path
+import random  # temporary: used for occasional debug printing in compute_loss
 from shutil import copyfile
-from typing import Optional
+from typing import List, Optional, Tuple
 
 import k2
 import torch
+from torch import Tensor
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from conformer import DiscreteBottleneckConformer
+from conformer import BidirectionalConformer
 from lhotse.utils import fix_random_seed
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
@@ -153,7 +156,7 @@ def get_params() -> AttributeDict:
             "exp_dir": Path("conformer_ctc_bn/exp_gloam_5e-4_0.85_discrete8"),
             "lang_dir": Path("data/lang_bpe"),
             "feature_dim": 80,
-            "subsampling_factor": 4,
+            "subsampling_factor": 4,  # can't be changed
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
@@ -166,12 +169,19 @@ def get_params() -> AttributeDict:
             "reduction": "sum",
             "use_double_scores": True,
             "accum_grad": 1,
-            "att_rate": 0.7,
+            "att_scale": 0.4,
+            "reverse_att_scale": 0.4,  # ctc_scale == 1.0 - att_scale - reverse_att_scale
             "attention_dim": 512,
             "nhead": 8,
+            "num_trunk_encoder_layers": 12,
+            "num_ctc_encoder_layers": 2,  # value assumed; run() below reads this key
             "num_decoder_layers": 6,
-            "is_espnet_structure": True,
-            "mmi_loss": False,
+            "num_reverse_encoder_layers": 4,
+            "num_reverse_decoder_layers": 4,
+            "num_self_predictor_layers": 2,
+            "discretization_tot_classes": 512,
+            "discretization_num_groups": 8,
+            "is_bpe": True,
             "use_feat_batchnorm": True,
             "max_lrate": 5.0e-04,
             "first_decay_epoch": 1,
@@ -270,15 +279,83 @@ def save_checkpoint(
     copyfile(src=filename, dst=best_valid_filename)
 
 
+class LossRecord(collections.defaultdict):
+    def __init__(self):
+        # Passing the type 'int' to the base-class constructor
+        # makes undefined items default to int(), which is zero.
+        super(LossRecord, self).__init__(int)
+
+    def __add__(self, other: "LossRecord") -> "LossRecord":
+        ans = LossRecord()
+        for k, v in self.items():
+            ans[k] = v
+        for k, v in other.items():
+            ans[k] = ans[k] + v
+        return ans
+
+    def __mul__(self, alpha: float) -> "LossRecord":
+        ans = LossRecord()
+        for k, v in self.items():
+            ans[k] = v * alpha
+        return ans
+
+    def __str__(self) -> str:
+        ans = ''
+        for k, v in self.norm_items():
+            norm_value = '%.2g' % v
+            ans += (str(k) + '=' + str(norm_value) + ', ')
+        frames = str(self['frames'])
+        ans += 'over ' + frames + ' frames.'
+        return ans
+
+    def norm_items(self) -> List[Tuple[str, float]]:
+        """
+        Returns a list of pairs, like:
+          [('ctc_loss', 0.1), ('att_loss', 0.07)]
+        """
+        num_frames = self['frames'] if 'frames' in self else 1
+        ans = []
+        for k, v in self.items():
+            if k != 'frames':
+                norm_value = float(v) / num_frames
+                ans.append((k, norm_value))
+        return ans
+
+    def reduce(self, device):
+        """
+        Reduce using torch.distributed, so that all processes get the total.
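+        All processes must call this at the same point, with the same keys.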
+        """
+        keys = sorted(self.keys())
+        s = torch.tensor([float(self[k]) for k in keys],
+                         device=device)
+        dist.all_reduce(s, op=dist.ReduceOp.SUM)
+        for k, v in zip(keys, s.cpu().tolist()):
+            self[k] = v
+
+    def write_summary(self, tb_writer: SummaryWriter, prefix: str, batch_idx: int) -> None:
+        """
+        Add logging information to a TensorBoard writer.
+            tb_writer: a TensorBoard writer
+            prefix: a prefix for the name of the loss, e.g. "train/valid_"
+                or "train/current_"
+            batch_idx: the current batch index, used as the x-axis of the plot
+        """
+        for k, v in self.norm_items():
+            tb_writer.add_scalar(prefix + k, v, batch_idx)
+
+
 def compute_loss(
     params: AttributeDict,
     model: nn.Module,
     batch: dict,
     graph_compiler: BpeCtcTrainingGraphCompiler,
     is_training: bool,
-):
+) -> Tuple[Tensor, LossRecord]:
     """
-    Compute CTC loss given the model and its inputs.
+    Compute the loss function (CTC, attention, and reverse-attention terms).
 
     Args:
       params:
@@ -306,9 +383,17 @@ def compute_loss(
 
     supervisions = batch["supervisions"]
 
+    # Unwrap DDP if necessary, so we can call methods other than forward().
+    mmodel = model.module if hasattr(model, "module") else model
+
     with torch.set_grad_enabled(is_training):
-        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
-        # nnet_output is [N, T, C]
+        memory, position_embedding, memory_mask = model(feature, supervisions)
+        # memory's shape is (N, T, C)
+
+        ctc_output = mmodel.ctc_encoder_forward(memory,
+                                                position_embedding,
+                                                memory_mask)
 
     # NOTE: We need `encode_supervisions` to sort sequences with
     # different duration in decreasing order, required by
@@ -322,7 +406,7 @@ def compute_loss(
     decoding_graph = graph_compiler.compile(token_ids)
 
     dense_fsa_vec = k2.DenseFsaVec(
-        nnet_output,
+        ctc_output,
         supervision_segments,
         allow_truncate=params.subsampling_factor - 1,
     )
@@ -335,38 +419,75 @@ def compute_loss(
         use_double_scores=params.use_double_scores,
     )
 
-    if params.att_rate != 0.0:
+    if params.att_scale != 0.0:
         with torch.set_grad_enabled(is_training):
-            if hasattr(model, "module"):
-                att_loss = model.module.decoder_forward(
-                    encoder_memory,
-                    memory_mask,
-                    token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
-                )
-            else:
-                att_loss = model.decoder_forward(
-                    encoder_memory,
-                    memory_mask,
-                    token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
-                )
-        loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
+            att_loss = mmodel.decoder_forward(
+                memory,
+                memory_mask,
+                token_ids=token_ids,
+                sos_id=graph_compiler.sos_id,
+                eos_id=graph_compiler.eos_id,
+            )
     else:
-        loss = ctc_loss
-        att_loss = torch.tensor([0])
+        att_loss = torch.tensor([0.0]).to(device)
 
-    # train_frames and valid_frames are used for printing.
-    if is_training:
-        params.train_frames = supervision_segments[:, 2].sum().item()
+    if params.reverse_att_scale != 0.0:
+        with torch.set_grad_enabled(is_training):
+            (sampled, softmax,
+             positive_embed_shifted,
+             negative_embed_shifted) = mmodel.sample_forward(memory)
+
+            reverse_decoder_logprob = mmodel.reverse_decoder_forward(
+                positive_embed_shifted,
+                memory_mask,
+                sampled, softmax,
+                token_ids,
+                sos_id=graph_compiler.sos_id,
+                eos_id=graph_compiler.eos_id,
+                padding_id=0)
+
+            self_prediction_logprob = mmodel.self_prediction_forward(
+                negative_embed_shifted,
+                memory_mask,
+                sampled, softmax)
+
+            # Note: reverse_att_loss is the negative of an estimate of the
+            # mutual information between the word sequence and the frames;
+            # it will generally be negative, and is to be minimized (i.e. it
+            # moves away from zero as we train; it does not approach zero).
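+            # As a rough sketch of the intent (this notation is assumed here,
+            # not taken from the model code):
+            #   reverse_att_loss = log q(sampled) - log p(sampled | words)
+            #                   ~= -I(sampled; words),
+            # so minimizing it maximizes that mutual-information estimate.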
+
+            reverse_att_loss = self_prediction_logprob - reverse_decoder_logprob
+
+            if random.random() < 0.01:
+                # Will eventually remove this debugging block.
+                num_frames = supervision_segments[:, 2].sum().item()
+                print(f"Self-prediction logprob = {self_prediction_logprob/num_frames}, "
+                      f"reverse-decoder logprob = {reverse_decoder_logprob/num_frames}, "
+                      f"reverse_att_loss = {reverse_att_loss/num_frames}")
     else:
-        params.valid_frames = supervision_segments[:, 2].sum().item()
+        reverse_att_loss = torch.tensor([0.0]).to(device)
 
+    ctc_scale = 1.0 - params.att_scale - params.reverse_att_scale
+    loss = (ctc_scale * ctc_loss +
+            params.att_scale * att_loss +
+            params.reverse_att_scale * reverse_att_loss)
 
     assert loss.requires_grad == is_training
 
-    return loss, ctc_loss.detach(), att_loss.detach()
+    info = LossRecord()
+    # TODO: there are many GPU->CPU transfers here; maybe combine them into one.
+    info['frames'] = supervision_segments[:, 2].sum().item()
+    info['ctc_loss'] = ctc_loss.detach().cpu().item()
+    if params.att_scale != 0.0:
+        info['att_loss'] = att_loss.detach().cpu().item()
+    if params.reverse_att_scale != 0.0:
+        info['reverse_att_loss'] = reverse_att_loss.detach().cpu().item()
+    info['loss'] = loss.detach().cpu().item()
+
+    return loss, info
     except RuntimeError as e:
         print(f"Runtime error. feature.shape = {feature.shape}, supervisions = {supervisions}")
         raise e
@@ -381,18 +498,13 @@ def compute_validation_loss(
     graph_compiler: BpeCtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
-) -> None:
-    """Run the validation process. The validation loss
-    is saved in `params.valid_loss`.
-    """
+) -> LossRecord:
+    """Run the validation process."""
     model.eval()
 
-    tot_loss = 0.0
-    tot_ctc_loss = 0.0
-    tot_att_loss = 0.0
-    tot_frames = 0.0
+    tot_loss = LossRecord()
 
     for batch_idx, batch in enumerate(valid_dl):
-        loss, ctc_loss, att_loss = compute_loss(
+        loss, loss_info = compute_loss(
             params=params,
             model=model,
             batch=batch,
@@ -400,36 +512,18 @@ def compute_validation_loss(
             is_training=False,
         )
         assert loss.requires_grad is False
-        assert ctc_loss.requires_grad is False
-        assert att_loss.requires_grad is False
-
-        loss_cpu = loss.detach().cpu().item()
-        tot_loss += loss_cpu
-
-        tot_ctc_loss += ctc_loss.detach().cpu().item()
-        tot_att_loss += att_loss.detach().cpu().item()
-
-        tot_frames += params.valid_frames
+        tot_loss = tot_loss + loss_info
 
     if world_size > 1:
-        s = torch.tensor(
-            [tot_loss, tot_ctc_loss, tot_att_loss, tot_frames],
-            device=loss.device,
-        )
-        dist.all_reduce(s, op=dist.ReduceOp.SUM)
-        s = s.cpu().tolist()
-        tot_loss = s[0]
-        tot_ctc_loss = s[1]
-        tot_att_loss = s[2]
-        tot_frames = s[3]
+        tot_loss.reduce(loss.device)
 
-    params.valid_loss = tot_loss / tot_frames
-    params.valid_ctc_loss = tot_ctc_loss / tot_frames
-    params.valid_att_loss = tot_att_loss / tot_frames
-
-    if params.valid_loss < params.best_valid_loss:
+    loss_value = tot_loss['loss'] / tot_loss['frames']
+    if loss_value < params.best_valid_loss:
         params.best_valid_epoch = params.cur_epoch
-        params.best_valid_loss = params.valid_loss
+        params.best_valid_loss = loss_value
+
+    return tot_loss
 
 
 def train_one_epoch(
@@ -468,24 +562,22 @@ def train_one_epoch(
     """
     model.train()
 
-    tot_loss = 0.0  # sum of losses over all batches
-    tot_ctc_loss = 0.0
-    tot_att_loss = 0.0
+    tot_loss = LossRecord()
 
-    tot_frames = 0.0  # sum of frames over all batches
-    params.tot_loss = 0.0
-    params.tot_frames = 0.0
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
len(batch["supervisions"]["text"]) - loss, ctc_loss, att_loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, graph_compiler=graph_compiler, is_training=True, ) + tot_loss = (tot_loss * (1 + 1 / params.reset_interval)) + loss_info # summary stats. # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. @@ -495,75 +585,22 @@ def train_one_epoch( clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() - loss_cpu = loss.detach().cpu().item() - ctc_loss_cpu = ctc_loss.detach().cpu().item() - att_loss_cpu = att_loss.detach().cpu().item() + if batch_idx % 10 == 0: - tot_frames += params.train_frames - tot_loss += loss_cpu - tot_ctc_loss += ctc_loss_cpu - tot_att_loss += att_loss_cpu - - params.tot_frames += params.train_frames - params.tot_loss += loss_cpu - - tot_avg_loss = tot_loss / tot_frames - tot_avg_ctc_loss = tot_ctc_loss / tot_frames - tot_avg_att_loss = tot_att_loss / tot_frames + if tb_writer is not None: + loss_info.write_summary(tb_writer, "train/current_", params.batch_idx_train) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx % params.log_interval == 0: logging.info( - f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, " - f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, " - f"batch avg loss {loss_cpu/params.train_frames:.4f}, " - f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, " - f"total avg att loss: {tot_avg_att_loss:.4f}, " - f"total avg loss: {tot_avg_loss:.4f}, " - f"batch size: {batch_size}" + f"Epoch {params.cur_epoch}, batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" ) - if tb_writer is not None: - tb_writer.add_scalar( - "train/current_ctc_loss", - ctc_loss_cpu / params.train_frames, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/current_att_loss", - att_loss_cpu / params.train_frames, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/current_loss", - loss_cpu / params.train_frames, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/tot_avg_ctc_loss", - tot_avg_ctc_loss, - params.batch_idx_train, - ) - - tb_writer.add_scalar( - "train/tot_avg_att_loss", - tot_avg_att_loss, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/tot_avg_loss", - tot_avg_loss, - params.batch_idx_train, - ) - if batch_idx > 0 and batch_idx % params.reset_interval == 0: - tot_loss = 0.0 # sum of losses over all batches - tot_ctc_loss = 0.0 - tot_att_loss = 0.0 - - tot_frames = 0.0 # sum of frames over all batches if batch_idx > 0 and batch_idx % params.valid_interval == 0: - compute_validation_loss( + logging.info("Computing validation loss") + valid_info = compute_validation_loss( params=params, model=model, graph_compiler=graph_compiler, @@ -572,32 +609,14 @@ def train_one_epoch( ) model.train() logging.info( - f"Epoch {params.cur_epoch}, " - f"valid ctc loss {params.valid_ctc_loss:.4f}," - f"valid att loss {params.valid_att_loss:.4f}," - f"valid loss {params.valid_loss:.4f}," - f" best valid loss: {params.best_valid_loss:.4f} " - f"best valid epoch: {params.best_valid_epoch}" + f"Epoch {params.cur_epoch}, validation: {valid_info}" ) if tb_writer is not None: - tb_writer.add_scalar( - "train/valid_ctc_loss", - params.valid_ctc_loss, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/valid_att_loss", - params.valid_att_loss, - params.batch_idx_train, - ) - 
-                tb_writer.add_scalar(
-                    "train/valid_loss",
-                    params.valid_loss,
-                    params.batch_idx_train,
-                )
+                valid_info.write_summary(tb_writer, "train/valid_", params.batch_idx_train)
 
-    params.train_loss = params.tot_loss / params.tot_frames
+    loss_value = tot_loss['loss'] / tot_loss['frames']
+    params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
         params.best_train_loss = params.train_loss
@@ -647,17 +666,21 @@ def run(rank, world_size, args):
     )
 
     logging.info("About to create model")
-    model = DiscreteBottleneckConformer(
+    model = BidirectionalConformer(
         num_features=params.feature_dim,
-        nhead=params.nhead,
-        d_model=params.attention_dim,
         num_classes=num_classes,
-        subsampling_factor=params.subsampling_factor,
+        d_model=params.attention_dim,
+        nhead=params.nhead,
+        num_trunk_encoder_layers=params.num_trunk_encoder_layers,
+        num_ctc_encoder_layers=params.num_ctc_encoder_layers,
         num_decoder_layers=params.num_decoder_layers,
-        vgg_frontend=False,
-        is_espnet_structure=params.is_espnet_structure,
-        mmi_loss=params.mmi_loss,
-        use_feat_batchnorm=params.use_feat_batchnorm,
+        num_reverse_encoder_layers=params.num_reverse_encoder_layers,
+        num_reverse_decoder_layers=params.num_reverse_decoder_layers,
+        num_self_predictor_layers=params.num_self_predictor_layers,
+        subsampling_factor=params.subsampling_factor,
+        is_bpe=params.is_bpe,
+        discretization_tot_classes=params.discretization_tot_classes,
+        discretization_num_groups=params.discretization_num_groups,
    )
 
     checkpoints = load_checkpoint_if_available(params=params, model=model)
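
Usage note (an illustrative sketch, not part of the patch): the LossRecord
arithmetic introduced above composes as follows; the names and numbers here
are made up.

    # Assumes LossRecord exactly as defined in the diff above.
    a = LossRecord()
    a['frames'] = 1000
    a['ctc_loss'] = 250.0

    b = LossRecord()
    b['frames'] = 500
    b['ctc_loss'] = 100.0

    tot = a + b      # key-wise sum: frames=1500, ctc_loss=350.0
    tot = tot * 0.9  # key-wise scale, as in the decaying sum in train_one_epoch
    print(tot.norm_items())  # [('ctc_loss', 0.2333...)], normalized by 'frames'
    print(tot)               # prints "ctc_loss=0.23, over 1350.0 frames."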