Update train.py

commit 0fa46bf68a (parent e74e75acc6)
Author: Mingshuang Luo (committed by GitHub)
Date:   2021-09-29 12:49:59 +08:00

@@ -14,7 +14,6 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
 from torch import Tensor
 from asr_datamodule import YesNoAsrDataModule
 from lhotse.utils import fix_random_seed
 from model import Tdnn
@ -27,7 +26,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, setup_logger, str2bool from icefall.utils import AttributeDict, LossRecord, setup_logger, str2bool
def get_parser(): def get_parser():
@@ -245,72 +244,6 @@ def save_checkpoint(
         copyfile(src=filename, dst=best_valid_filename)
 
 
-class LossRecord(collections.defaultdict):
-    def __init__(self):
-        # Passing the type 'int' to the base-class constructor
-        # makes undefined items default to int() which is zero.
-        super(LossRecord, self).__init__(int)
-
-    def __add__(self, other: 'LossRecord') -> 'LossRecord':
-        ans = LossRecord()
-        for k, v in self.items():
-            ans[k] = v
-        for k, v in other.items():
-            ans[k] = ans[k] + v
-        return ans
-
-    def __mul__(self, alpha: float) -> 'LossRecord':
-        ans = LossRecord()
-        for k, v in self.items():
-            ans[k] = v * alpha
-        return ans
-
-    def __str__(self) -> str:
-        ans = ''
-        for k, v in self.norm_items():
-            norm_value = '%.4g' % v
-            ans += (str(k) + '=' + str(norm_value) + ', ')
-        frames = str(self['frames'])
-        ans += 'over ' + frames + ' frames.'
-        return ans
-
-    def norm_items(self) -> List[Tuple[str, float]]:
-        """
-        Returns a list of pairs, like:
-          [('ctc_loss', 0.1), ('att_loss', 0.07)]
-        """
-        num_frames = self['frames'] if 'frames' in self else 1
-        ans = []
-        for k, v in self.items():
-            if k != 'frames':
-                norm_value = float(v) / num_frames
-                ans.append((k, norm_value))
-        return ans
-
-    def reduce(self, device):
-        """
-        Reduce using torch.distributed, which I believe ensures that
-        all processes get the total.
-        """
-        keys = sorted(self.keys())
-        s = torch.tensor([float(self[k]) for k in keys],
-                         device=device)
-        dist.all_reduce(s, op=dist.ReduceOp.SUM)
-        for k, v in zip(keys, s.cpu().tolist()):
-            self[k] = v
-
-    def write_summary(self, tb_writer: SummaryWriter, prefix: str, batch_idx: int) -> None:
-        """
-        Add logging information to a TensorBoard writer.
-
-          tb_writer: a TensorBoard writer
-          prefix: a prefix for the name of the loss, e.g. "train/valid_",
-                  or "train/current_"
-          batch_idx: The current batch index, used as the x-axis of the plot.
-        """
-        for k, v in self.norm_items():
-            tb_writer.add_scalar(prefix + k, v, batch_idx)
-
-
 def compute_loss(
     params: AttributeDict,
     model: nn.Module,
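The class deleted above is not dropped: it now lives in icefall.utils, which is why the earlier hunk adds LossRecord to the import from icefall.utils. As a minimal usage sketch (not part of this commit; the key names and values are made up for illustration, and it assumes `from icefall.utils import LossRecord` resolves after the change):

from icefall.utils import LossRecord

# Per-batch statistics; 'frames' is the normalizer used by norm_items() and __str__.
info = LossRecord()
info['frames'] = 320        # hypothetical number of frames in the batch
info['loss'] = 12.8         # hypothetical summed loss over the batch

tot = LossRecord()
tot = tot + info            # __add__ merges keys, summing values
print(tot.norm_items())     # [('loss', 0.04)]  -- loss divided by frames
print(tot)                  # "loss=0.04, over 320 frames."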
@@ -376,9 +309,8 @@ def compute_loss(
     )
     assert loss.requires_grad == is_training
 
     info = LossRecord()
-    # TODO: there are many GPU->CPU transfers here, maybe combine them into one.
     info['frames'] = supervision_segments[:, 2].sum().item()
     info['loss'] = loss.detach().cpu().item()
@@ -398,7 +330,7 @@ def compute_validation_loss(
     model.eval()
 
     tot_loss = LossRecord()
 
     for batch_idx, batch in enumerate(valid_dl):
         loss, loss_info = compute_loss(
             params=params,
@@ -410,7 +342,7 @@ def compute_validation_loss(
         assert loss.requires_grad is False
 
         tot_loss = tot_loss + loss_info
 
     if world_size > 1:
         tot_loss.reduce(loss.device)
@@ -458,7 +390,7 @@ def train_one_epoch(
         Number of nodes in DDP training. If it is 1, DDP is disabled.
     """
     model.train()
 
     tot_loss = LossRecord()
 
     for batch_idx, batch in enumerate(train_dl):
@@ -473,10 +405,7 @@ def train_one_epoch(
             is_training=True,
         )
         # summary stats.
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
-        # NOTE: We use reduction==sum and loss is computed over utterances
-        # in the batch and there is no normalization to it so far.
 
         optimizer.zero_grad()
         loss.backward()
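The update `tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info` keeps an exponentially decayed running total rather than a plain sum, so the logged statistics reflect roughly the last reset_interval batches. A small self-contained sketch of the arithmetic (the value 200 for params.reset_interval is only an assumption for illustration):

# With a constant per-batch loss of 1.0, the decayed total converges to
# 1 / (1 - decay) == reset_interval, i.e. it behaves like a sum over
# roughly the last `reset_interval` batches.
reset_interval = 200            # hypothetical value of params.reset_interval
decay = 1 - 1 / reset_interval  # 0.995

tot = 0.0
for _ in range(1000):
    tot = tot * decay + 1.0
print(tot)                      # close to 200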