From e74e75acc639763d83dd69aa2d2962014dafb1dc Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Wed, 29 Sep 2021 10:08:38 +0800 Subject: [PATCH] Use LossRecord to record and print loss for the training process --- egs/librispeech/ASR/conformer_ctc/train.py | 255 ++++++++++----------- egs/librispeech/ASR/tdnn_lstm_ctc/train.py | 177 ++++++++------ egs/yesno/ASR/tdnn/train.py | 168 +++++++++----- 3 files changed, 337 insertions(+), 263 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 80b2d924a..fcb895394 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang) +# Wei Kang +# Mingshuang Luo) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -18,16 +19,21 @@ import argparse +import collections +import copy import logging from pathlib import Path from shutil import copyfile -from typing import Optional +from typing import Optional, Tuple, List + import k2 import torch import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn +from torch import Tensor + from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer from lhotse.utils import fix_random_seed @@ -281,13 +287,80 @@ def save_checkpoint( copyfile(src=filename, dst=best_valid_filename) +class LossRecord(collections.defaultdict): + def __init__(self): + # Passing the type 'int' to the base-class constructor + # makes undefined items default to int() which is zero. + super(LossRecord, self).__init__(int) + + def __add__(self, other: 'LossRecord') -> 'LossRecord': + ans = LossRecord() + for k, v in self.items(): + ans[k] = v + for k, v in other.items(): + ans[k] = ans[k] + v + return ans + + def __mul__(self, alpha: float) -> 'LossRecord': + ans = LossRecord() + for k, v in self.items(): + ans[k] = v * alpha + return ans + + + def __str__(self) -> str: + ans = '' + for k, v in self.norm_items(): + norm_value = '%.4g' % v + ans += (str(k) + '=' + str(norm_value) + ', ') + frames = str(self['frames']) + ans += 'over ' + frames + ' frames.' + return ans + + def norm_items(self) -> List[Tuple[str, float]]: + """ + Returns a list of pairs, like: + [('ctc_loss', 0.1), ('att_loss', 0.07)] + """ + num_frames = self['frames'] if 'frames' in self else 1 + ans = [] + for k, v in self.items(): + if k != 'frames': + norm_value = float(v) / num_frames + ans.append((k, norm_value)) + return ans + + def reduce(self, device): + """ + Reduce using torch.distributed, which I believe ensures that + all processes get the total. + """ + keys = sorted(self.keys()) + s = torch.tensor([ float(self[k]) for k in keys ], + device=device) + dist.all_reduce(s, op=dist.ReduceOp.SUM) + for k, v in zip(keys, s.cpu().tolist()): + self[k] = v + + def write_summary(self, tb_writer: SummaryWriter, prefix: str, batch_idx: int) -> None: + """ + Add logging information to a TensorBoard writer. + tb_writer: a TensorBoard writer + prefix: a prefix for the name of the loss, e.g. "train/valid_", + or "train/current_" + batch_idx: The current batch index, used as the x-axis of the plot. 
+ """ + for k, v in self.norm_items(): + tb_writer.add_scalar(prefix + k, v, batch_idx) + + def compute_loss( params: AttributeDict, model: nn.Module, batch: dict, graph_compiler: BpeCtcTrainingGraphCompiler, is_training: bool, -): +) -> Tuple[Tensor, LossRecord]: """ Compute CTC loss given the model and its inputs. @@ -367,15 +440,18 @@ def compute_loss( loss = ctc_loss att_loss = torch.tensor([0]) - # train_frames and valid_frames are used for printing. - if is_training: - params.train_frames = supervision_segments[:, 2].sum().item() - else: - params.valid_frames = supervision_segments[:, 2].sum().item() - assert loss.requires_grad == is_training - return loss, ctc_loss.detach(), att_loss.detach() + info = LossRecord() + # TODO: there are many GPU->CPU transfers here, maybe combine them into one. + info['frames'] = supervision_segments[:, 2].sum().item() + info['ctc_loss'] = ctc_loss.detach().cpu().item() + if params.att_rate != 0.0: + info['att_loss'] = att_loss.detach().cpu().item() + + info['loss'] = loss.detach().cpu().item() + + return loss, info def compute_validation_loss( @@ -384,18 +460,14 @@ def compute_validation_loss( graph_compiler: BpeCtcTrainingGraphCompiler, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, -) -> None: - """Run the validation process. The validation loss - is saved in `params.valid_loss`. - """ +) -> LossRecord: + """Run the validation process.""" model.eval() - tot_loss = 0.0 - tot_ctc_loss = 0.0 - tot_att_loss = 0.0 - tot_frames = 0.0 + tot_loss = LossRecord() + for batch_idx, batch in enumerate(valid_dl): - loss, ctc_loss, att_loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, @@ -403,36 +475,17 @@ def compute_validation_loss( is_training=False, ) assert loss.requires_grad is False - assert ctc_loss.requires_grad is False - assert att_loss.requires_grad is False - - loss_cpu = loss.detach().cpu().item() - tot_loss += loss_cpu - - tot_ctc_loss += ctc_loss.detach().cpu().item() - tot_att_loss += att_loss.detach().cpu().item() - - tot_frames += params.valid_frames + tot_loss = tot_loss + loss_info if world_size > 1: - s = torch.tensor( - [tot_loss, tot_ctc_loss, tot_att_loss, tot_frames], - device=loss.device, - ) - dist.all_reduce(s, op=dist.ReduceOp.SUM) - s = s.cpu().tolist() - tot_loss = s[0] - tot_ctc_loss = s[1] - tot_att_loss = s[2] - tot_frames = s[3] + tot_loss.reduce(loss.device) - params.valid_loss = tot_loss / tot_frames - params.valid_ctc_loss = tot_ctc_loss / tot_frames - params.valid_att_loss = tot_att_loss / tot_frames - - if params.valid_loss < params.best_valid_loss: + loss_value = tot_loss['loss'] / tot_loss['frames'] + if loss_value < params.best_valid_loss: params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = params.valid_loss + params.best_valid_loss = loss_value + + return tot_loss def train_one_epoch( @@ -471,24 +524,21 @@ def train_one_epoch( """ model.train() - tot_loss = 0.0 # sum of losses over all batches - tot_ctc_loss = 0.0 - tot_att_loss = 0.0 + tot_loss = LossRecord() - tot_frames = 0.0 # sum of frames over all batches - params.tot_loss = 0.0 - params.tot_frames = 0.0 for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - loss, ctc_loss, att_loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, graph_compiler=graph_compiler, is_training=True, ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info # NOTE: We 
use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. @@ -498,75 +548,21 @@ def train_one_epoch( clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() - loss_cpu = loss.detach().cpu().item() - ctc_loss_cpu = ctc_loss.detach().cpu().item() - att_loss_cpu = att_loss.detach().cpu().item() - - tot_frames += params.train_frames - tot_loss += loss_cpu - tot_ctc_loss += ctc_loss_cpu - tot_att_loss += att_loss_cpu - - params.tot_frames += params.train_frames - params.tot_loss += loss_cpu - - tot_avg_loss = tot_loss / tot_frames - tot_avg_ctc_loss = tot_ctc_loss / tot_frames - tot_avg_att_loss = tot_att_loss / tot_frames - if batch_idx % params.log_interval == 0: logging.info( - f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, " - f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, " - f"batch avg loss {loss_cpu/params.train_frames:.4f}, " - f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, " - f"total avg att loss: {tot_avg_att_loss:.4f}, " - f"total avg loss: {tot_avg_loss:.4f}, " - f"batch size: {batch_size}" + f"Epoch {params.cur_epoch}, batch {batch_idx}, loss[{loss_info}]" + f"tot_loss[{tot_loss}], batch size: {batch_size}" ) + if batch_idx % 10 == 0: + if tb_writer is not None: - tb_writer.add_scalar( - "train/current_ctc_loss", - ctc_loss_cpu / params.train_frames, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/current_att_loss", - att_loss_cpu / params.train_frames, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/current_loss", - loss_cpu / params.train_frames, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/tot_avg_ctc_loss", - tot_avg_ctc_loss, - params.batch_idx_train, - ) - - tb_writer.add_scalar( - "train/tot_avg_att_loss", - tot_avg_att_loss, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/tot_avg_loss", - tot_avg_loss, - params.batch_idx_train, - ) - if batch_idx > 0 and batch_idx % params.reset_interval == 0: - tot_loss = 0.0 # sum of losses over all batches - tot_ctc_loss = 0.0 - tot_att_loss = 0.0 - - tot_frames = 0.0 # sum of frames over all batches + loss_info.write_summary(tb_writer, "train/current_", params.batch_idx_train) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: - compute_validation_loss( + logging.info("Computing validation loss") + valid_info = compute_validation_loss( params=params, model=model, graph_compiler=graph_compiler, @@ -575,32 +571,13 @@ def train_one_epoch( ) model.train() logging.info( - f"Epoch {params.cur_epoch}, " - f"valid ctc loss {params.valid_ctc_loss:.4f}," - f"valid att loss {params.valid_att_loss:.4f}," - f"valid loss {params.valid_loss:.4f}," - f" best valid loss: {params.best_valid_loss:.4f} " - f"best valid epoch: {params.best_valid_epoch}" - ) + f"Epoch {params.cur_epoch}, validation: {valid_info}" + ) if tb_writer is not None: - tb_writer.add_scalar( - "train/valid_ctc_loss", - params.valid_ctc_loss, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/valid_att_loss", - params.valid_att_loss, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/valid_loss", - params.valid_loss, - params.batch_idx_train, - ) - - params.train_loss = params.tot_loss / params.tot_frames + valid_info.write_summary(tb_writer, "train/valid_", params.batch_idx_train) + loss_value = tot_loss['loss'] / tot_loss['frames'] + params.train_loss = loss_value if 
params.train_loss < params.best_train_loss: params.best_train_epoch = params.cur_epoch params.best_train_loss = params.train_loss @@ -739,4 +716,4 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py index 695ee5130..babfd07a9 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Mingshuang Luo) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -18,9 +19,10 @@ import argparse import logging +import collections from pathlib import Path from shutil import copyfile -from typing import Optional +from typing import Optional, Tuple, List import k2 import torch @@ -28,6 +30,8 @@ import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn import torch.optim as optim +from torch import Tensor + from asr_datamodule import LibriSpeechAsrDataModule from lhotse.utils import fix_random_seed from model import TdnnLstm @@ -260,6 +264,71 @@ def save_checkpoint( best_valid_filename = params.exp_dir / "best-valid-loss.pt" copyfile(src=filename, dst=best_valid_filename) +class LossRecord(collections.defaultdict): + def __init__(self): + # Passing the type 'int' to the base-class constructor + # makes undefined items default to int() which is zero. + super(LossRecord, self).__init__(int) + + def __add__(self, other: 'LossRecord') -> 'LossRecord': + ans = LossRecord() + for k, v in self.items(): + ans[k] = v + for k, v in other.items(): + ans[k] = ans[k] + v + return ans + + def __mul__(self, alpha: float) -> 'LossRecord': + ans = LossRecord() + for k, v in self.items(): + ans[k] = v * alpha + return ans + + def __str__(self) -> str: + ans = '' + for k, v in self.norm_items(): + norm_value = '%.4g' % v + ans += (str(k) + '=' + str(norm_value) + ', ') + frames = str(self['frames']) + ans += 'over ' + frames + ' frames.' + return ans + + def norm_items(self) -> List[Tuple[str, float]]: + """ + Returns a list of pairs, like: + [('ctc_loss', 0.1), ('att_loss', 0.07)] + """ + num_frames = self['frames'] if 'frames' in self else 1 + ans = [] + for k, v in self.items(): + if k != 'frames': + norm_value = float(v) / num_frames + ans.append((k, norm_value)) + return ans + + def reduce(self, device): + """ + Reduce using torch.distributed, which I believe ensures that + all processes get the total. + """ + keys = sorted(self.keys()) + s = torch.tensor([ float(self[k]) for k in keys ], + device=device) + dist.all_reduce(s, op=dist.ReduceOp.SUM) + for k, v in zip(keys, s.cpu().tolist()): + self[k] = v + + def write_summary(self, tb_writer: SummaryWriter, prefix: str, batch_idx: int) -> None: + """ + Add logging information to a TensorBoard writer. + tb_writer: a TensorBoard writer + prefix: a prefix for the name of the loss, e.g. "train/valid_", + or "train/current_" + batch_idx: The current batch index, used as the x-axis of the plot. + """ + for k, v in self.norm_items(): + tb_writer.add_scalar(prefix + k, v, batch_idx) + def compute_loss( params: AttributeDict, @@ -267,7 +336,7 @@ def compute_loss( batch: dict, graph_compiler: CtcTrainingGraphCompiler, is_training: bool, -): +) -> Tuple[Tensor, LossRecord]: """ Compute CTC loss given the model and its inputs. 
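To make the intent of the shared LossRecord class concrete, here is a minimal usage sketch. It assumes the class definition from this patch (and its imports) is already in scope; the loss and frame values are invented for illustration, and only the arithmetic and formatting behaviour are taken from the code above.

# Assumes the LossRecord class from this patch is in scope; values are made up.
batch1 = LossRecord()
batch1['frames'] = 1000
batch1['ctc_loss'] = 120.0
batch1['att_loss'] = 80.0
batch1['loss'] = 110.0

batch2 = LossRecord()
batch2['frames'] = 800
batch2['ctc_loss'] = 88.0
batch2['att_loss'] = 56.0
batch2['loss'] = 78.4

total = batch1 + batch2      # __add__ sums every key, including 'frames'
print(total.norm_items())    # per-frame values, 'frames' excluded:
                             # roughly [('ctc_loss', 0.116), ('att_loss', 0.076), ('loss', 0.105)]
print(total)                 # "ctc_loss=0.1156, att_loss=0.07556, loss=0.1047, over 1800 frames."

half = total * 0.5           # __mul__ scales every key, 'frames' included,
print(half['frames'])        # 900.0 -- so the normalized values are unchanged by scaling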
@@ -324,13 +393,12 @@ def compute_loss( assert loss.requires_grad == is_training - # train_frames and valid_frames are used for printing. - if is_training: - params.train_frames = supervision_segments[:, 2].sum().item() - else: - params.valid_frames = supervision_segments[:, 2].sum().item() + info = LossRecord() + # TODO: there are many GPU->CPU transfers here, maybe combine them into one. + info['frames'] = supervision_segments[:, 2].sum().item() + info['loss'] = loss.detach().cpu().item() - return loss + return loss, info def compute_validation_loss( @@ -339,16 +407,16 @@ def compute_validation_loss( graph_compiler: CtcTrainingGraphCompiler, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, -) -> None: +) -> LossRecord: """Run the validation process. The validation loss is saved in `params.valid_loss`. """ model.eval() - tot_loss = 0.0 - tot_frames = 0.0 + tot_loss = LossRecord() + for batch_idx, batch in enumerate(valid_dl): - loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, @@ -357,22 +425,18 @@ def compute_validation_loss( ) assert loss.requires_grad is False - loss_cpu = loss.detach().cpu().item() - tot_loss += loss_cpu - tot_frames += params.valid_frames - + tot_loss = tot_loss + loss_info + if world_size > 1: - s = torch.tensor([tot_loss, tot_frames], device=loss.device) - dist.all_reduce(s, op=dist.ReduceOp.SUM) - s = s.cpu().tolist() - tot_loss = s[0] - tot_frames = s[1] + tot_loss.reduce(loss.device) - params.valid_loss = tot_loss / tot_frames + loss_value = tot_loss['loss'] / tot_loss['frames'] - if params.valid_loss < params.best_valid_loss: + if loss_value < params.best_valid_loss: params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = params.valid_loss + params.best_valid_loss = loss_value + + return tot_loss def train_one_epoch( @@ -411,23 +475,21 @@ def train_one_epoch( """ model.train() - tot_loss = 0.0 # reset after params.reset_interval of batches - tot_frames = 0.0 # reset after params.reset_interval of batches - - params.tot_loss = 0.0 - params.tot_frames = 0.0 + tot_loss = LossRecord() for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, graph_compiler=graph_compiler, is_training=True, ) + # summary stats. + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. 
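One behavioural change worth calling out: the old code zeroed its running totals every params.reset_interval batches, while the new `tot_loss = tot_loss * (1 - 1/params.reset_interval) + loss_info` keeps an exponentially decaying sum whose effective window is roughly reset_interval batches. Because 'frames' decays by the same factor, `tot_loss['loss'] / tot_loss['frames']` remains a correctly weighted per-frame average. A plain-Python sketch of the same arithmetic (the batch numbers are invented, and R merely stands in for params.reset_interval):

# Plain-dict sketch of the decaying average used above; numbers are made up.
R = 200                      # stands in for params.reset_interval
decay = 1.0 - 1.0 / R

tot = {'frames': 0.0, 'loss': 0.0}
for frames, loss in [(1000, 110.0), (900, 95.0), (1100, 118.0)]:
    tot = {k: v * decay for k, v in tot.items()}   # tot_loss * (1 - 1/R)
    tot['frames'] += frames                        # ... + loss_info
    tot['loss'] += loss
    print(tot['loss'] / tot['frames'])             # per-frame loss over roughly the last R batches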
@@ -437,41 +499,19 @@ def train_one_epoch( clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() - loss_cpu = loss.detach().cpu().item() - - tot_frames += params.train_frames - tot_loss += loss_cpu - tot_avg_loss = tot_loss / tot_frames - - params.tot_frames += params.train_frames - params.tot_loss += loss_cpu - if batch_idx % params.log_interval == 0: logging.info( - f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"batch avg loss {loss_cpu/params.train_frames:.4f}, " - f"total avg loss: {tot_avg_loss:.4f}, " - f"batch size: {batch_size}" + f"Epoch {params.cur_epoch}, batch {batch_idx}, loss[{loss_info}]" + f"tot_loss[{tot_loss}], batch size: {batch_size}" ) + if batch_idx % 10 == 0: + if tb_writer is not None: - tb_writer.add_scalar( - "train/current_loss", - loss_cpu / params.train_frames, - params.batch_idx_train, - ) - - tb_writer.add_scalar( - "train/tot_avg_loss", - tot_avg_loss, - params.batch_idx_train, - ) - - if batch_idx > 0 and batch_idx % params.reset_interval == 0: - tot_loss = 0 - tot_frames = 0 + loss_info.write_summary(tb_writer, "train/current_", params.batch_idx_train) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: - compute_validation_loss( + valid_info = compute_validation_loss( params=params, model=model, graph_compiler=graph_compiler, @@ -480,12 +520,17 @@ def train_one_epoch( ) model.train() logging.info( - f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f}," - f" best valid loss: {params.best_valid_loss:.4f} " - f"best valid epoch: {params.best_valid_epoch}" - ) + f"Epoch {params.cur_epoch}, validation {valid_info}" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, + "train/valid_", + params.batch_idx_train, + ) - params.train_loss = params.tot_loss / params.tot_frames + loss_value = tot_loss['loss'] / tot_loss['frames'] + params.train_loss = loss_value if params.train_loss < params.best_train_loss: params.best_train_epoch = params.cur_epoch @@ -613,4 +658,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py index 0f5506d38..398119569 100755 --- a/egs/yesno/ASR/tdnn/train.py +++ b/egs/yesno/ASR/tdnn/train.py @@ -2,9 +2,10 @@ import argparse import logging +import collections from pathlib import Path from shutil import copyfile -from typing import Optional +from typing import Optional, Tuple, List import k2 import torch @@ -12,6 +13,8 @@ import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn import torch.optim as optim +from torch import Tensor + from asr_datamodule import YesNoAsrDataModule from lhotse.utils import fix_random_seed from model import Tdnn @@ -122,6 +125,8 @@ def get_params() -> AttributeDict: - valid_interval: Run validation if batch_idx % valid_interval` is 0 + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + - beam_size: It is used in k2.ctc_loss - reduction: It is used in k2.ctc_loss @@ -142,6 +147,7 @@ def get_params() -> AttributeDict: "best_valid_epoch": -1, "batch_idx_train": 0, "log_interval": 10, + "reset_interval": 20, "valid_interval": 10, "beam_size": 10, "reduction": "sum", @@ -239,13 +245,79 @@ def save_checkpoint( copyfile(src=filename, dst=best_valid_filename) +class LossRecord(collections.defaultdict): + def __init__(self): + # Passing the type 'int' to the base-class constructor + # makes undefined items default to int() which is zero. 
+ super(LossRecord, self).__init__(int) + + def __add__(self, other: 'LossRecord') -> 'LossRecord': + ans = LossRecord() + for k, v in self.items(): + ans[k] = v + for k, v in other.items(): + ans[k] = ans[k] + v + return ans + + def __mul__(self, alpha: float) -> 'LossRecord': + ans = LossRecord() + for k, v in self.items(): + ans[k] = v * alpha + return ans + + def __str__(self) -> str: + ans = '' + for k, v in self.norm_items(): + norm_value = '%.4g' % v + ans += (str(k) + '=' + str(norm_value) + ', ') + frames = str(self['frames']) + ans += 'over ' + frames + ' frames.' + return ans + + def norm_items(self) -> List[Tuple[str, float]]: + """ + Returns a list of pairs, like: + [('ctc_loss', 0.1), ('att_loss', 0.07)] + """ + num_frames = self['frames'] if 'frames' in self else 1 + ans = [] + for k, v in self.items(): + if k != 'frames': + norm_value = float(v) / num_frames + ans.append((k, norm_value)) + return ans + + def reduce(self, device): + """ + Reduce using torch.distributed, which I believe ensures that + all processes get the total. + """ + keys = sorted(self.keys()) + s = torch.tensor([ float(self[k]) for k in keys ], + device=device) + dist.all_reduce(s, op=dist.ReduceOp.SUM) + for k, v in zip(keys, s.cpu().tolist()): + self[k] = v + + def write_summary(self, tb_writer: SummaryWriter, prefix: str, batch_idx: int) -> None: + """ + Add logging information to a TensorBoard writer. + tb_writer: a TensorBoard writer + prefix: a prefix for the name of the loss, e.g. "train/valid_", + or "train/current_" + batch_idx: The current batch index, used as the x-axis of the plot. + """ + for k, v in self.norm_items(): + tb_writer.add_scalar(prefix + k, v, batch_idx) + + def compute_loss( params: AttributeDict, model: nn.Module, batch: dict, graph_compiler: CtcTrainingGraphCompiler, is_training: bool, -): +) -> Tuple[Tensor, LossRecord]: """ Compute CTC loss given the model and its inputs. @@ -304,14 +376,13 @@ def compute_loss( ) assert loss.requires_grad == is_training + + info = LossRecord() + # TODO: there are many GPU->CPU transfers here, maybe combine them into one. + info['frames'] = supervision_segments[:, 2].sum().item() + info['loss'] = loss.detach().cpu().item() - # train_frames and valid_frames are used for printing. - if is_training: - params.train_frames = supervision_segments[:, 2].sum().item() - else: - params.valid_frames = supervision_segments[:, 2].sum().item() - - return loss + return loss, info def compute_validation_loss( @@ -320,16 +391,16 @@ def compute_validation_loss( graph_compiler: CtcTrainingGraphCompiler, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, -) -> None: +) -> LossRecord: """Run the validation process. The validation loss is saved in `params.valid_loss`. 
""" model.eval() - tot_loss = 0.0 - tot_frames = 0.0 + tot_loss = LossRecord() + for batch_idx, batch in enumerate(valid_dl): - loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, @@ -338,22 +409,18 @@ def compute_validation_loss( ) assert loss.requires_grad is False - loss_cpu = loss.detach().cpu().item() - tot_loss += loss_cpu - tot_frames += params.valid_frames - + tot_loss = tot_loss + loss_info + if world_size > 1: - s = torch.tensor([tot_loss, tot_frames], device=loss.device) - dist.all_reduce(s, op=dist.ReduceOp.SUM) - s = s.cpu().tolist() - tot_loss = s[0] - tot_frames = s[1] + tot_loss.reduce(loss.device) - params.valid_loss = tot_loss / tot_frames + loss_value = tot_loss['loss'] / tot_loss['frames'] - if params.valid_loss < params.best_valid_loss: + if loss_value < params.best_valid_loss: params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = params.valid_loss + params.best_valid_loss = loss_value + + return tot_loss def train_one_epoch( @@ -391,20 +458,22 @@ def train_one_epoch( Number of nodes in DDP training. If it is 1, DDP is disabled. """ model.train() + + tot_loss = LossRecord() - tot_loss = 0.0 # sum of losses over all batches - tot_frames = 0.0 # sum of frames over all batches for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, graph_compiler=graph_compiler, is_training=True, ) + # summary stats. + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. @@ -414,35 +483,19 @@ def train_one_epoch( clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() - loss_cpu = loss.detach().cpu().item() - - tot_frames += params.train_frames - tot_loss += loss_cpu - tot_avg_loss = tot_loss / tot_frames - if batch_idx % params.log_interval == 0: logging.info( - f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"batch avg loss {loss_cpu/params.train_frames:.4f}, " - f"total avg loss: {tot_avg_loss:.4f}, " - f"batch size: {batch_size}" + f"Epoch {params.cur_epoch}, batch {batch_idx}, loss[{loss_info}]" + f"tot_loss[{tot_loss}], batch size: {batch_size}" ) + if batch_idx % 10 == 0: if tb_writer is not None: - tb_writer.add_scalar( - "train/current_loss", - loss_cpu / params.train_frames, - params.batch_idx_train, - ) - - tb_writer.add_scalar( - "train/tot_avg_loss", - tot_avg_loss, - params.batch_idx_train, - ) + loss_info.write_summary(tb_writer, "train/current_", params.batch_idx_train) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: - compute_validation_loss( + valid_info = compute_validation_loss( params=params, model=model, graph_compiler=graph_compiler, @@ -451,18 +504,17 @@ def train_one_epoch( ) model.train() logging.info( - f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f}," - f" best valid loss: {params.best_valid_loss:.4f} " - f"best valid epoch: {params.best_valid_epoch}" - ) + f"Epoch {params.cur_epoch}, validation {valid_info}" + ) if tb_writer is not None: - tb_writer.add_scalar( - "train/valid_loss", - params.valid_loss, + valid_info.write_summary( + tb_writer, + "train/valid_", params.batch_idx_train, ) - params.train_loss = tot_loss / tot_frames + loss_value = tot_loss['loss'] / tot_loss['frames'] + 
params.train_loss = loss_value

    if params.train_loss < params.best_train_loss:
        params.best_train_epoch = params.cur_epoch
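Finally, a sketch of what write_summary produces in TensorBoard, since the scalar tags shift slightly: the per-batch tags (train/current_ctc_loss, train/current_loss, ...) and the validation tags (train/valid_loss, ...) keep their old names, while the running averages move from train/tot_avg_* to train/tot_*. The writer construction and the values below are illustrative only, and the sketch again assumes the LossRecord class from this patch is in scope.

from torch.utils.tensorboard import SummaryWriter

# Illustrative only: the log_dir and the values are made up, and the
# LossRecord class from this patch must already be in scope.
tb_writer = SummaryWriter(log_dir="conformer_ctc/exp/tensorboard")

info = LossRecord()
info['frames'] = 1000
info['ctc_loss'] = 120.0
info['att_loss'] = 80.0
info['loss'] = 110.0

# Writes train/current_ctc_loss=0.12, train/current_att_loss=0.08 and
# train/current_loss=0.11, each normalized per frame, with the given
# batch index on the x-axis.
info.write_summary(tb_writer, "train/current_", batch_idx=100)
tb_writer.close()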