From 0fa46bf68ac2b46a48869846565dd3047f209ee8 Mon Sep 17 00:00:00 2001
From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com>
Date: Wed, 29 Sep 2021 12:49:59 +0800
Subject: [PATCH] Update train.py

---
 egs/yesno/ASR/tdnn/train.py | 83 +++----------------------------------
 1 file changed, 6 insertions(+), 77 deletions(-)

diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py
index 398119569..ae5fe0235 100755
--- a/egs/yesno/ASR/tdnn/train.py
+++ b/egs/yesno/ASR/tdnn/train.py
@@ -14,7 +14,6 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
 from torch import Tensor
-
 from asr_datamodule import YesNoAsrDataModule
 from lhotse.utils import fix_random_seed
 from model import Tdnn
@@ -27,7 +26,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, setup_logger, str2bool
+from icefall.utils import AttributeDict, LossRecord, setup_logger, str2bool
 
 
 def get_parser():
@@ -245,72 +244,6 @@ def save_checkpoint(
         copyfile(src=filename, dst=best_valid_filename)
 
 
-class LossRecord(collections.defaultdict):
-    def __init__(self):
-        # Passing the type 'int' to the base-class constructor
-        # makes undefined items default to int() which is zero.
-        super(LossRecord, self).__init__(int)
-
-    def __add__(self, other: 'LossRecord') -> 'LossRecord':
-        ans = LossRecord()
-        for k, v in self.items():
-            ans[k] = v
-        for k, v in other.items():
-            ans[k] = ans[k] + v
-        return ans
-
-    def __mul__(self, alpha: float) -> 'LossRecord':
-        ans = LossRecord()
-        for k, v in self.items():
-            ans[k] = v * alpha
-        return ans
-
-    def __str__(self) -> str:
-        ans = ''
-        for k, v in self.norm_items():
-            norm_value = '%.4g' % v
-            ans += (str(k) + '=' + str(norm_value) + ', ')
-        frames = str(self['frames'])
-        ans += 'over ' + frames + ' frames.'
-        return ans
-
-    def norm_items(self) -> List[Tuple[str, float]]:
-        """
-        Returns a list of pairs, like:
-          [('ctc_loss', 0.1), ('att_loss', 0.07)]
-        """
-        num_frames = self['frames'] if 'frames' in self else 1
-        ans = []
-        for k, v in self.items():
-            if k != 'frames':
-                norm_value = float(v) / num_frames
-                ans.append((k, norm_value))
-        return ans
-
-    def reduce(self, device):
-        """
-        Reduce using torch.distributed, which I believe ensures that
-        all processes get the total.
-        """
-        keys = sorted(self.keys())
-        s = torch.tensor([ float(self[k]) for k in keys ],
-                         device=device)
-        dist.all_reduce(s, op=dist.ReduceOp.SUM)
-        for k, v in zip(keys, s.cpu().tolist()):
-            self[k] = v
-
-    def write_summary(self, tb_writer: SummaryWriter, prefix: str, batch_idx: int) -> None:
-        """
-        Add logging information to a TensorBoard writer.
-            tb_writer: a TensorBoard writer
-            prefix: a prefix for the name of the loss, e.g. "train/valid_",
-                    or "train/current_"
-            batch_idx: The current batch index, used as the x-axis of the plot.
-        """
-        for k, v in self.norm_items():
-            tb_writer.add_scalar(prefix + k, v, batch_idx)
-
-
 def compute_loss(
     params: AttributeDict,
     model: nn.Module,
@@ -376,9 +309,8 @@ def compute_loss(
         )
 
     assert loss.requires_grad == is_training
-
+
     info = LossRecord()
-    # TODO: there are many GPU->CPU transfers here, maybe combine them into one.
     info['frames'] = supervision_segments[:, 2].sum().item()
     info['loss'] = loss.detach().cpu().item()
@@ -398,7 +330,7 @@ def compute_validation_loss(
     model.eval()
 
     tot_loss = LossRecord()
-
+
     for batch_idx, batch in enumerate(valid_dl):
         loss, loss_info = compute_loss(
             params=params,
@@ -410,7 +342,7 @@
         assert loss.requires_grad is False
 
         tot_loss = tot_loss + loss_info
-
+
     if world_size > 1:
         tot_loss.reduce(loss.device)
@@ -458,7 +390,7 @@ def train_one_epoch(
       Number of nodes in DDP training. If it is 1, DDP is disabled.
     """
     model.train()
-
+
     tot_loss = LossRecord()
 
     for batch_idx, batch in enumerate(train_dl):
@@ -473,10 +405,7 @@ def train_one_epoch(
             is_training=True,
         )
         # summary stats.
-        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
-
-        # NOTE: We use reduction==sum and loss is computed over utterances
-        # in the batch and there is no normalization to it so far.
+        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
 
         optimizer.zero_grad()
         loss.backward()
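Aside from whitespace and comment cleanup, the patch only drops the local LossRecord class and relies on the copy imported from icefall.utils. As a minimal sketch (not part of the patch) of how the imported class behaves, assuming it keeps the semantics of the class body removed above:

# Minimal sketch, assuming LossRecord in icefall.utils keeps the semantics of
# the class deleted from train.py above: a defaultdict(int) with per-key
# addition, per-key scaling, and normalization by the 'frames' entry.
from icefall.utils import LossRecord

info = LossRecord()
info['frames'] = 100
info['loss'] = 42.0

other = LossRecord()
other['frames'] = 60
other['loss'] = 30.0

tot = (info + other) * 0.5   # per-key sums, then every entry scaled by 0.5
print(tot.norm_items())      # [('loss', 0.45)] -- 'loss' divided by 'frames'
print(tot)                   # "loss=0.45, over 80.0 frames."

This is the same accumulation pattern train_one_epoch uses when it folds loss_info into tot_loss with a 1 - 1 / params.reset_interval decay factor.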