Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)

Update train.py

commit 34e36a926b
parent 597ff01158
@@ -19,8 +19,6 @@
 
 
 import argparse
-import collections
-import copy
 import logging
 from pathlib import Path
 from shutil import copyfile
@@ -49,6 +47,7 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    LossRecord,
     encode_supervisions,
     setup_logger,
     str2bool,
@@ -287,73 +286,6 @@ def save_checkpoint(
         copyfile(src=filename, dst=best_valid_filename)
 
 
-class LossRecord(collections.defaultdict):
-    def __init__(self):
-        # Passing the type 'int' to the base-class constructor
-        # makes undefined items default to int() which is zero.
-        super(LossRecord, self).__init__(int)
-
-    def __add__(self, other: 'LossRecord') -> 'LossRecord':
-        ans = LossRecord()
-        for k, v in self.items():
-            ans[k] = v
-        for k, v in other.items():
-            ans[k] = ans[k] + v
-        return ans
-
-    def __mul__(self, alpha: float) -> 'LossRecord':
-        ans = LossRecord()
-        for k, v in self.items():
-            ans[k] = v * alpha
-        return ans
-
-
-    def __str__(self) -> str:
-        ans = ''
-        for k, v in self.norm_items():
-            norm_value = '%.4g' % v
-            ans += (str(k) + '=' + str(norm_value) + ', ')
-        frames = str(self['frames'])
-        ans += 'over ' + frames + ' frames.'
-        return ans
-
-    def norm_items(self) -> List[Tuple[str, float]]:
-        """
-        Returns a list of pairs, like:
-          [('ctc_loss', 0.1), ('att_loss', 0.07)]
-        """
-        num_frames = self['frames'] if 'frames' in self else 1
-        ans = []
-        for k, v in self.items():
-            if k != 'frames':
-                norm_value = float(v) / num_frames
-                ans.append((k, norm_value))
-        return ans
-
-    def reduce(self, device):
-        """
-        Reduce using torch.distributed, which I believe ensures that
-        all processes get the total.
-        """
-        keys = sorted(self.keys())
-        s = torch.tensor([float(self[k]) for k in keys],
-                         device=device)
-        dist.all_reduce(s, op=dist.ReduceOp.SUM)
-        for k, v in zip(keys, s.cpu().tolist()):
-            self[k] = v
-
-    def write_summary(self, tb_writer: SummaryWriter, prefix: str, batch_idx: int) -> None:
-        """
-        Add logging information to a TensorBoard writer.
-            tb_writer: a TensorBoard writer
-            prefix: a prefix for the name of the loss, e.g. "train/valid_",
-                or "train/current_"
-            batch_idx: The current batch index, used as the x-axis of the plot.
-        """
-        for k, v in self.norm_items():
-            tb_writer.add_scalar(prefix + k, v, batch_idx)
-
-
 def compute_loss(
     params: AttributeDict,
     model: nn.Module,
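For reference, the removed LossRecord is essentially a defaultdict of per-batch counters, with 'frames' acting as the normalizer. Below is a minimal, self-contained sketch of how two records combine and how norm_items reports per-frame values; the class body is condensed from the removed code above (torch-free parts only) and the numbers are made up.

import collections
from typing import List, Tuple


class LossRecord(collections.defaultdict):
    """Condensed copy of the class removed above (torch-free parts only)."""

    def __init__(self):
        # Missing keys default to int() == 0, so records can be summed blindly.
        super().__init__(int)

    def __add__(self, other: "LossRecord") -> "LossRecord":
        ans = LossRecord()
        for k, v in self.items():
            ans[k] = v
        for k, v in other.items():
            ans[k] = ans[k] + v
        return ans

    def norm_items(self) -> List[Tuple[str, float]]:
        # Every entry except 'frames' is reported per frame.
        num_frames = self["frames"] if "frames" in self else 1
        return [(k, float(v) / num_frames) for k, v in self.items() if k != "frames"]


batch1 = LossRecord()
batch1["frames"], batch1["ctc_loss"] = 100, 12.0

batch2 = LossRecord()
batch2["frames"], batch2["ctc_loss"] = 300, 30.0

total = batch1 + batch2
print(total.norm_items())  # [('ctc_loss', 0.105)]  i.e. 42.0 loss over 400 frames

This only illustrates the bookkeeping; the full class also defines __mul__, __str__, reduce, and write_summary, as shown in the removed block.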
@@ -443,7 +375,6 @@ def compute_loss(
     assert loss.requires_grad == is_training
 
     info = LossRecord()
-    # TODO: there are many GPU->CPU transfers here, maybe combine them into one.
     info['frames'] = supervision_segments[:, 2].sum().item()
     info['ctc_loss'] = ctc_loss.detach().cpu().item()
     if params.att_rate != 0.0:
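The 'frames' and 'ctc_loss' entries filled in here are the counters that LossRecord.reduce later sums across data-parallel workers. A hedged, standalone sketch of that all-reduce pattern follows; the function name and example values are invented for illustration, and in a single-process run the reduction is simply skipped.

import torch
import torch.distributed as dist


def reduce_counters(counters: dict, device: torch.device) -> dict:
    # Same pattern as LossRecord.reduce: pack the values into one tensor,
    # sum it across all ranks, then unpack back into a dict.
    keys = sorted(counters.keys())
    s = torch.tensor([float(counters[k]) for k in keys], device=device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(s, op=dist.ReduceOp.SUM)
    return dict(zip(keys, s.cpu().tolist()))


if __name__ == "__main__":
    # Values a single rank might produce in compute_loss (illustrative only).
    info = {"frames": 400.0, "ctc_loss": 42.0}
    print(reduce_counters(info, device=torch.device("cpu")))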