Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-26 18:24:18 +00:00
Update utils.py
This commit is contained in:
parent 1c0792796b
commit 8d07ce2185
utils.py
@@ -1,4 +1,5 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
+#                                       Mingshuang Luo)
 #
 # See ../../LICENSE for clarification regarding multiple authors
 #
@@ -17,6 +18,7 @@
 
 import argparse
 import logging
+import collections
 import os
 import subprocess
 from collections import defaultdict
@@ -29,6 +31,7 @@ import k2
 import kaldialign
 import torch
 import torch.distributed as dist
+from torch.utils.tensorboard import SummaryWriter
 
 Pathlike = Union[str, Path]
 
@@ -419,3 +422,73 @@ def write_error_stats(
 
         print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
     return float(tot_err_rate)
+
+
+class LossRecord(collections.defaultdict):
+    def __init__(self):
+        # Passing the type 'int' to the base-class constructor
+        # makes undefined items default to int() which is zero.
+        super(LossRecord, self).__init__(int)
+
+    def __add__(self, other: 'LossRecord') -> 'LossRecord':
+        ans = LossRecord()
+        for k, v in self.items():
+            ans[k] = v
+        for k, v in other.items():
+            ans[k] = ans[k] + v
+        return ans
+
+    def __mul__(self, alpha: float) -> 'LossRecord':
+        ans = LossRecord()
+        for k, v in self.items():
+            ans[k] = v * alpha
+        return ans
+
+    def __str__(self) -> str:
+        ans = ''
+        for k, v in self.norm_items():
+            norm_value = '%.4g' % v
+            ans += (str(k) + '=' + str(norm_value) + ', ')
+        frames = str(self['frames'])
+        ans += 'over ' + frames + ' frames.'
+        return ans
+
+    def norm_items(self) -> List[Tuple[str, float]]:
+        """
+        Returns a list of pairs, like:
+          [('ctc_loss', 0.1), ('att_loss', 0.07)]
+        """
+        num_frames = self['frames'] if 'frames' in self else 1
+        ans = []
+        for k, v in self.items():
+            if k != 'frames':
+                norm_value = float(v) / num_frames
+                ans.append((k, norm_value))
+        return ans
+
+    def reduce(self, device):
+        """
+        Reduce using torch.distributed, which I believe ensures that
+        all processes get the total.
+        """
+        keys = sorted(self.keys())
+        s = torch.tensor([ float(self[k]) for k in keys ],
+                         device=device)
+        dist.all_reduce(s, op=dist.ReduceOp.SUM)
+        for k, v in zip(keys, s.cpu().tolist()):
+            self[k] = v
+
+    def write_summary(self,
+                      tb_writer: SummaryWriter,
+                      prefix: str,
+                      batch_idx: int) -> None:
+        """Add logging information to a TensorBoard writer.
+
+        Args:
+            tb_writer: a TensorBoard writer
+            prefix: a prefix for the name of the loss, e.g. "train/valid_",
+                or "train/current_"
+            batch_idx: The current batch index, used as the x-axis of the plot.
+        """
+        for k, v in self.norm_items():
+            tb_writer.add_scalar(prefix + k, v, batch_idx)