From 8d07ce21859ca7d3dd498dec8c290f901d1604cc Mon Sep 17 00:00:00 2001
From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com>
Date: Wed, 29 Sep 2021 13:11:52 +0800
Subject: [PATCH] Update utils.py

---
 icefall/utils.py | 75 +++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 74 insertions(+), 1 deletion(-)

diff --git a/icefall/utils.py b/icefall/utils.py
index 23b4dd6c7..1abc77262 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -1,4 +1,5 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
+#                                       Mingshuang Luo)
 #
 # See ../../LICENSE for clarification regarding multiple authors
 #
@@ -17,6 +18,7 @@
 
 import argparse
 import logging
+import collections
 import os
 import subprocess
 from collections import defaultdict
@@ -29,6 +31,7 @@ import k2
 import kaldialign
 import torch
 import torch.distributed as dist
+from torch.utils.tensorboard import SummaryWriter
 
 Pathlike = Union[str, Path]
 
@@ -419,3 +422,73 @@ def write_error_stats(
         print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
 
     return float(tot_err_rate)
+
+
+class LossRecord(collections.defaultdict):
+    def __init__(self):
+        # Passing the type 'int' to the base-class constructor
+        # makes undefined items default to int(), which is zero.
+        super().__init__(int)
+
+    def __add__(self, other: 'LossRecord') -> 'LossRecord':
+        ans = LossRecord()
+        for k, v in self.items():
+            ans[k] = v
+        for k, v in other.items():
+            ans[k] = ans[k] + v
+        return ans
+
+    def __mul__(self, alpha: float) -> 'LossRecord':
+        ans = LossRecord()
+        for k, v in self.items():
+            ans[k] = v * alpha
+        return ans
+
+    def __str__(self) -> str:
+        ans = ''
+        for k, v in self.norm_items():
+            norm_value = '%.4g' % v
+            ans += (str(k) + '=' + str(norm_value) + ', ')
+        frames = str(self['frames'])
+        ans += 'over ' + frames + ' frames.'
+        return ans
+
+    def norm_items(self) -> List[Tuple[str, float]]:
+        """
+        Returns a list of pairs, normalized by the 'frames' count, like:
+          [('ctc_loss', 0.1), ('att_loss', 0.07)]
+        """
+        num_frames = self['frames'] if 'frames' in self else 1
+        ans = []
+        for k, v in self.items():
+            if k != 'frames':
+                norm_value = float(v) / num_frames
+                ans.append((k, norm_value))
+        return ans
+
+    def reduce(self, device):
+        """
+        Reduce using torch.distributed, so that after this call every
+        process holds the sum of the stats over all processes.
+        """
+        keys = sorted(self.keys())
+        s = torch.tensor([float(self[k]) for k in keys],
+                         device=device)
+        dist.all_reduce(s, op=dist.ReduceOp.SUM)
+        for k, v in zip(keys, s.cpu().tolist()):
+            self[k] = v
+
+    def write_summary(self,
+                      tb_writer: SummaryWriter,
+                      prefix: str,
+                      batch_idx: int) -> None:
+        """Add logging information to a TensorBoard writer.
+
+        Args:
+            tb_writer: a TensorBoard writer
+            prefix: a prefix for the name of the loss, e.g. "train/valid_"
+                or "train/current_"
+            batch_idx: The current batch index, used as the x-axis of the plot.
+        """
+        for k, v in self.norm_items():
+            tb_writer.add_scalar(prefix + k, v, batch_idx)
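
A minimal usage sketch, not part of the patch itself: it assumes the patch
above has been applied and that icefall.utils is importable. The loss names
('ctc_loss', 'att_loss') and all numbers are illustrative, not taken from
the patch.

    from icefall.utils import LossRecord

    # Per-batch statistics: store raw (un-normalized) loss sums together
    # with the number of frames they were accumulated over.
    info_a = LossRecord()
    info_a['frames'] = 100
    info_a['ctc_loss'] = 12.0
    info_a['att_loss'] = 7.0

    info_b = LossRecord()
    info_b['frames'] = 150
    info_b['ctc_loss'] = 15.0
    info_b['att_loss'] = 9.0

    # __add__ merges the raw sums; keys missing from either side read as
    # zero because the base class is defaultdict(int).
    total = info_a + info_b

    # __mul__ scales all raw values, e.g. for an exponentially decaying
    # running total: running = running * 0.9 + current.
    decayed = total * 0.9

    # __str__ normalizes by 'frames', so the printed losses are per frame:
    # "ctc_loss=0.108, att_loss=0.064, over 250 frames."
    print(total)

Subclassing defaultdict(int) is what lets a training loop accumulate
arbitrary named statistics without declaring them up front: any key not yet
written reads as zero, so __add__ works even when the two records carry
different sets of keys.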
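
A second sketch for the logging and multi-process paths, assuming an
already-initialized torch.distributed process group and a CUDA device; the
log directory and batch index below are placeholders.

    import torch
    from torch.utils.tensorboard import SummaryWriter
    from icefall.utils import LossRecord

    tb_writer = SummaryWriter(log_dir='tensorboard')  # placeholder dir

    info = LossRecord()
    info['frames'] = 100
    info['ctc_loss'] = 12.0

    # reduce() sums the raw stats over all processes via all_reduce, so
    # every rank ends up with the same totals. It requires that
    # dist.init_process_group(...) was called earlier.
    device = torch.device('cuda', torch.cuda.current_device())
    info.reduce(device)

    # write_summary() logs the per-frame values, here under the tag
    # 'train/valid_ctc_loss', with the batch index as the x-axis.
    info.write_summary(tb_writer, 'train/valid_', batch_idx=0)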