Update train.py

Mingshuang Luo 2021-09-29 12:49:59 +08:00 committed by GitHub
parent e74e75acc6
commit 0fa46bf68a


@@ -14,7 +14,6 @@ import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch import Tensor
from asr_datamodule import YesNoAsrDataModule
from lhotse.utils import fix_random_seed
from model import Tdnn
@@ -27,7 +26,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, setup_logger, str2bool
+from icefall.utils import AttributeDict, LossRecord, setup_logger, str2bool
def get_parser():
@@ -245,72 +244,6 @@ def save_checkpoint(
        copyfile(src=filename, dst=best_valid_filename)
-class LossRecord(collections.defaultdict):
-    def __init__(self):
-        # Passing the type 'int' to the base-class constructor
-        # makes undefined items default to int() which is zero.
-        super(LossRecord, self).__init__(int)
-
-    def __add__(self, other: 'LossRecord') -> 'LossRecord':
-        ans = LossRecord()
-        for k, v in self.items():
-            ans[k] = v
-        for k, v in other.items():
-            ans[k] = ans[k] + v
-        return ans
-
-    def __mul__(self, alpha: float) -> 'LossRecord':
-        ans = LossRecord()
-        for k, v in self.items():
-            ans[k] = v * alpha
-        return ans
-
-    def __str__(self) -> str:
-        ans = ''
-        for k, v in self.norm_items():
-            norm_value = '%.4g' % v
-            ans += (str(k) + '=' + str(norm_value) + ', ')
-        frames = str(self['frames'])
-        ans += 'over ' + frames + ' frames.'
-        return ans
-
-    def norm_items(self) -> List[Tuple[str, float]]:
-        """
-        Returns a list of pairs, like:
-          [('ctc_loss', 0.1), ('att_loss', 0.07)]
-        """
-        num_frames = self['frames'] if 'frames' in self else 1
-        ans = []
-        for k, v in self.items():
-            if k != 'frames':
-                norm_value = float(v) / num_frames
-                ans.append((k, norm_value))
-        return ans
-
-    def reduce(self, device):
-        """
-        Reduce using torch.distributed, which I believe ensures that
-        all processes get the total.
-        """
-        keys = sorted(self.keys())
-        s = torch.tensor([float(self[k]) for k in keys], device=device)
-        dist.all_reduce(s, op=dist.ReduceOp.SUM)
-        for k, v in zip(keys, s.cpu().tolist()):
-            self[k] = v
-
-    def write_summary(self, tb_writer: SummaryWriter, prefix: str, batch_idx: int) -> None:
-        """
-        Add logging information to a TensorBoard writer.
-
-        tb_writer: a TensorBoard writer
-        prefix: a prefix for the name of the loss, e.g. "train/valid_",
-            or "train/current_"
-        batch_idx: The current batch index, used as the x-axis of the plot.
-        """
-        for k, v in self.norm_items():
-            tb_writer.add_scalar(prefix + k, v, batch_idx)
def compute_loss(
    params: AttributeDict,
    model: nn.Module,
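
The net effect of this commit is relocation: the `LossRecord` class above is deleted from `train.py` and imported from `icefall.utils` instead (see the import change in the second hunk). A small usage sketch may help; the import path follows the new import line, but the keys and numbers below are invented for illustration.

```python
# Illustrative only: how LossRecord behaves after the move to icefall.utils.
from icefall.utils import LossRecord

batch1 = LossRecord()
batch1["frames"] = 500  # acoustic frames in this batch
batch1["loss"] = 40.0   # loss summed over the batch (reduction="sum")

batch2 = LossRecord()
batch2["frames"] = 300
batch2["loss"] = 24.0

# __add__ merges the counters key by key; absent keys default to 0.
tot = batch1 + batch2
assert tot["frames"] == 800

# norm_items() divides every entry except "frames" by the frame count,
# so __str__ reports a per-frame loss:
print(tot)  # -> "loss=0.08, over 800 frames."
```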
@@ -376,9 +309,8 @@ def compute_loss(
    )
    assert loss.requires_grad == is_training

    info = LossRecord()
    # TODO: there are many GPU->CPU transfers here, maybe combine them into one.
    info['frames'] = supervision_segments[:, 2].sum().item()
    info['loss'] = loss.detach().cpu().item()
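
The TODO in this hunk points out that each `.item()` / `.cpu()` call above is a separate device-to-host transfer. One possible way to batch them into a single copy is sketched below; this is hypothetical and not part of the commit, and `fill_stats` and its arguments are illustrative names.

```python
import torch

# Hypothetical sketch for the TODO above: gather both statistics on the
# device, then pay for a single device-to-host copy instead of two.
def fill_stats(info, supervision_segments, loss):
    stats = torch.stack(
        [
            # .to(loss) matches both the dtype and device of `loss`.
            supervision_segments[:, 2].sum().to(loss),
            loss.detach(),
        ]
    )
    num_frames, loss_value = stats.cpu().tolist()  # one transfer
    info["frames"] = int(num_frames)
    info["loss"] = loss_value
```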
@@ -398,7 +330,7 @@ def compute_validation_loss(
    model.eval()

    tot_loss = LossRecord()

    for batch_idx, batch in enumerate(valid_dl):
        loss, loss_info = compute_loss(
            params=params,
@@ -410,7 +342,7 @@ def compute_validation_loss(
        assert loss.requires_grad is False
        tot_loss = tot_loss + loss_info

    if world_size > 1:
        tot_loss.reduce(loss.device)
@@ -458,7 +390,7 @@ def train_one_epoch(
        Number of nodes in DDP training. If it is 1, DDP is disabled.
    """
    model.train()

    tot_loss = LossRecord()

    for batch_idx, batch in enumerate(train_dl):
@@ -473,10 +405,7 @@ def train_one_epoch(
            is_training=True,
        )

        # summary stats.
-        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
-
-        # NOTE: We use reduction==sum and loss is computed over utterances
-        # in the batch and there is no normalization to it so far.
+        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

        optimizer.zero_grad()
        loss.backward()
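
The `tot_loss` update in the last hunk works because `LossRecord` defines both `__mul__` and `__add__`: scaling by `1 - 1 / params.reset_interval` decays old batches geometrically, so `tot_loss` tracks a running sum over roughly the last `reset_interval` batches rather than the whole epoch. A plain-float sketch of the same update, with an assumed `reset_interval` of 200:

```python
# Plain-float sketch of the decayed running sum used for `tot_loss`.
# reset_interval is a training parameter; 200 is just an assumed value.
reset_interval = 200
decay = 1 - 1 / reset_interval  # each step keeps 99.5% of the history

tot_loss = 0.0
for batch_loss in [2.0, 1.8, 1.9]:
    # Contributions from old batches shrink by `decay` every step, so the
    # running value reflects roughly the last `reset_interval` batches.
    tot_loss = tot_loss * decay + batch_loss
```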