Use LossRecord to record and print loss for the training process

This commit is contained in:
Mingshuang Luo 2021-09-29 10:08:38 +08:00
parent 73f21a379b
commit e74e75acc6
3 changed files with 337 additions and 263 deletions

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang) # Wei Kang
# Mingshuang Luo)
# #
# See ../../../../LICENSE for clarification regarding multiple authors # See ../../../../LICENSE for clarification regarding multiple authors
# #
@ -18,16 +19,21 @@
import argparse import argparse
import collections
import copy
import logging import logging
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from typing import Optional from typing import Optional, Tuple, List
import k2 import k2
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from torch import Tensor
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer from conformer import Conformer
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
@ -281,13 +287,80 @@ def save_checkpoint(
copyfile(src=filename, dst=best_valid_filename) copyfile(src=filename, dst=best_valid_filename)
class LossRecord(collections.defaultdict):
def __init__(self):
# Passing the type 'int' to the base-class constructor
# makes undefined items default to int() which is zero.
super(LossRecord, self).__init__(int)
def __add__(self, other: 'LossRecord') -> 'LossRecord':
ans = LossRecord()
for k, v in self.items():
ans[k] = v
for k, v in other.items():
ans[k] = ans[k] + v
return ans
def __mul__(self, alpha: float) -> 'LossRecord':
ans = LossRecord()
for k, v in self.items():
ans[k] = v * alpha
return ans
def __str__(self) -> str:
ans = ''
for k, v in self.norm_items():
norm_value = '%.4g' % v
ans += (str(k) + '=' + str(norm_value) + ', ')
frames = str(self['frames'])
ans += 'over ' + frames + ' frames.'
return ans
def norm_items(self) -> List[Tuple[str, float]]:
"""
Returns a list of pairs, like:
[('ctc_loss', 0.1), ('att_loss', 0.07)]
"""
num_frames = self['frames'] if 'frames' in self else 1
ans = []
for k, v in self.items():
if k != 'frames':
norm_value = float(v) / num_frames
ans.append((k, norm_value))
return ans
def reduce(self, device):
"""
Reduce using torch.distributed, which I believe ensures that
all processes get the total.
"""
keys = sorted(self.keys())
s = torch.tensor([ float(self[k]) for k in keys ],
device=device)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
for k, v in zip(keys, s.cpu().tolist()):
self[k] = v
def write_summary(self, tb_writer: SummaryWriter, prefix: str, batch_idx: int) -> None:
"""
Add logging information to a TensorBoard writer.
tb_writer: a TensorBoard writer
prefix: a prefix for the name of the loss, e.g. "train/valid_",
or "train/current_"
batch_idx: The current batch index, used as the x-axis of the plot.
"""
for k, v in self.norm_items():
tb_writer.add_scalar(prefix + k, v, batch_idx)
def compute_loss( def compute_loss(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
batch: dict, batch: dict,
graph_compiler: BpeCtcTrainingGraphCompiler, graph_compiler: BpeCtcTrainingGraphCompiler,
is_training: bool, is_training: bool,
): ) -> Tuple[Tensor, LossRecord]:
""" """
Compute CTC loss given the model and its inputs. Compute CTC loss given the model and its inputs.
@ -367,15 +440,18 @@ def compute_loss(
loss = ctc_loss loss = ctc_loss
att_loss = torch.tensor([0]) att_loss = torch.tensor([0])
# train_frames and valid_frames are used for printing.
if is_training:
params.train_frames = supervision_segments[:, 2].sum().item()
else:
params.valid_frames = supervision_segments[:, 2].sum().item()
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
return loss, ctc_loss.detach(), att_loss.detach() info = LossRecord()
# TODO: there are many GPU->CPU transfers here, maybe combine them into one.
info['frames'] = supervision_segments[:, 2].sum().item()
info['ctc_loss'] = ctc_loss.detach().cpu().item()
if params.att_rate != 0.0:
info['att_loss'] = att_loss.detach().cpu().item()
info['loss'] = loss.detach().cpu().item()
return loss, info
def compute_validation_loss( def compute_validation_loss(
@ -384,18 +460,14 @@ def compute_validation_loss(
graph_compiler: BpeCtcTrainingGraphCompiler, graph_compiler: BpeCtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader,
world_size: int = 1, world_size: int = 1,
) -> None: ) -> LossRecord:
"""Run the validation process. The validation loss """Run the validation process."""
is saved in `params.valid_loss`.
"""
model.eval() model.eval()
tot_loss = 0.0 tot_loss = LossRecord()
tot_ctc_loss = 0.0
tot_att_loss = 0.0
tot_frames = 0.0
for batch_idx, batch in enumerate(valid_dl): for batch_idx, batch in enumerate(valid_dl):
loss, ctc_loss, att_loss = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
batch=batch, batch=batch,
@ -403,36 +475,17 @@ def compute_validation_loss(
is_training=False, is_training=False,
) )
assert loss.requires_grad is False assert loss.requires_grad is False
assert ctc_loss.requires_grad is False tot_loss = tot_loss + loss_info
assert att_loss.requires_grad is False
loss_cpu = loss.detach().cpu().item()
tot_loss += loss_cpu
tot_ctc_loss += ctc_loss.detach().cpu().item()
tot_att_loss += att_loss.detach().cpu().item()
tot_frames += params.valid_frames
if world_size > 1: if world_size > 1:
s = torch.tensor( tot_loss.reduce(loss.device)
[tot_loss, tot_ctc_loss, tot_att_loss, tot_frames],
device=loss.device,
)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
s = s.cpu().tolist()
tot_loss = s[0]
tot_ctc_loss = s[1]
tot_att_loss = s[2]
tot_frames = s[3]
params.valid_loss = tot_loss / tot_frames loss_value = tot_loss['loss'] / tot_loss['frames']
params.valid_ctc_loss = tot_ctc_loss / tot_frames if loss_value < params.best_valid_loss:
params.valid_att_loss = tot_att_loss / tot_frames
if params.valid_loss < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = params.valid_loss params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch( def train_one_epoch(
@ -471,24 +524,21 @@ def train_one_epoch(
""" """
model.train() model.train()
tot_loss = 0.0 # sum of losses over all batches tot_loss = LossRecord()
tot_ctc_loss = 0.0
tot_att_loss = 0.0
tot_frames = 0.0 # sum of frames over all batches
params.tot_loss = 0.0
params.tot_frames = 0.0
for batch_idx, batch in enumerate(train_dl): for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
loss, ctc_loss, att_loss = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
batch=batch, batch=batch,
graph_compiler=graph_compiler, graph_compiler=graph_compiler,
is_training=True, is_training=True,
) )
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances # NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far. # in the batch and there is no normalization to it so far.
@ -498,75 +548,21 @@ def train_one_epoch(
clip_grad_norm_(model.parameters(), 5.0, 2.0) clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step() optimizer.step()
loss_cpu = loss.detach().cpu().item()
ctc_loss_cpu = ctc_loss.detach().cpu().item()
att_loss_cpu = att_loss.detach().cpu().item()
tot_frames += params.train_frames
tot_loss += loss_cpu
tot_ctc_loss += ctc_loss_cpu
tot_att_loss += att_loss_cpu
params.tot_frames += params.train_frames
params.tot_loss += loss_cpu
tot_avg_loss = tot_loss / tot_frames
tot_avg_ctc_loss = tot_ctc_loss / tot_frames
tot_avg_att_loss = tot_att_loss / tot_frames
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
logging.info( logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, " f"Epoch {params.cur_epoch}, batch {batch_idx}, loss[{loss_info}]"
f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, " f"tot_loss[{tot_loss}], batch size: {batch_size}"
f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, "
f"total avg att loss: {tot_avg_att_loss:.4f}, "
f"total avg loss: {tot_avg_loss:.4f}, "
f"batch size: {batch_size}"
) )
if batch_idx % 10 == 0:
if tb_writer is not None: if tb_writer is not None:
tb_writer.add_scalar( loss_info.write_summary(tb_writer, "train/current_", params.batch_idx_train)
"train/current_ctc_loss", tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
ctc_loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/current_att_loss",
att_loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/current_loss",
loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_ctc_loss",
tot_avg_ctc_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_att_loss",
tot_avg_att_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_loss",
tot_avg_loss,
params.batch_idx_train,
)
if batch_idx > 0 and batch_idx % params.reset_interval == 0:
tot_loss = 0.0 # sum of losses over all batches
tot_ctc_loss = 0.0
tot_att_loss = 0.0
tot_frames = 0.0 # sum of frames over all batches
if batch_idx > 0 and batch_idx % params.valid_interval == 0: if batch_idx > 0 and batch_idx % params.valid_interval == 0:
compute_validation_loss( logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params, params=params,
model=model, model=model,
graph_compiler=graph_compiler, graph_compiler=graph_compiler,
@ -575,32 +571,13 @@ def train_one_epoch(
) )
model.train() model.train()
logging.info( logging.info(
f"Epoch {params.cur_epoch}, " f"Epoch {params.cur_epoch}, validation: {valid_info}"
f"valid ctc loss {params.valid_ctc_loss:.4f},"
f"valid att loss {params.valid_att_loss:.4f},"
f"valid loss {params.valid_loss:.4f},"
f" best valid loss: {params.best_valid_loss:.4f} "
f"best valid epoch: {params.best_valid_epoch}"
) )
if tb_writer is not None: if tb_writer is not None:
tb_writer.add_scalar( valid_info.write_summary(tb_writer, "train/valid_", params.batch_idx_train)
"train/valid_ctc_loss",
params.valid_ctc_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/valid_att_loss",
params.valid_att_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/valid_loss",
params.valid_loss,
params.batch_idx_train,
)
params.train_loss = params.tot_loss / params.tot_frames
loss_value = tot_loss['loss'] / tot_loss['frames']
params.train_loss = loss_value
if params.train_loss < params.best_train_loss: if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss params.best_train_loss = params.train_loss

View File

@ -1,5 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Mingshuang Luo)
# #
# See ../../../../LICENSE for clarification regarding multiple authors # See ../../../../LICENSE for clarification regarding multiple authors
# #
@ -18,9 +19,10 @@
import argparse import argparse
import logging import logging
import collections
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from typing import Optional from typing import Optional, Tuple, List
import k2 import k2
import torch import torch
@ -28,6 +30,8 @@ import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from torch import Tensor
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from model import TdnnLstm from model import TdnnLstm
@ -260,6 +264,71 @@ def save_checkpoint(
best_valid_filename = params.exp_dir / "best-valid-loss.pt" best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename) copyfile(src=filename, dst=best_valid_filename)
class LossRecord(collections.defaultdict):
def __init__(self):
# Passing the type 'int' to the base-class constructor
# makes undefined items default to int() which is zero.
super(LossRecord, self).__init__(int)
def __add__(self, other: 'LossRecord') -> 'LossRecord':
ans = LossRecord()
for k, v in self.items():
ans[k] = v
for k, v in other.items():
ans[k] = ans[k] + v
return ans
def __mul__(self, alpha: float) -> 'LossRecord':
ans = LossRecord()
for k, v in self.items():
ans[k] = v * alpha
return ans
def __str__(self) -> str:
ans = ''
for k, v in self.norm_items():
norm_value = '%.4g' % v
ans += (str(k) + '=' + str(norm_value) + ', ')
frames = str(self['frames'])
ans += 'over ' + frames + ' frames.'
return ans
def norm_items(self) -> List[Tuple[str, float]]:
"""
Returns a list of pairs, like:
[('ctc_loss', 0.1), ('att_loss', 0.07)]
"""
num_frames = self['frames'] if 'frames' in self else 1
ans = []
for k, v in self.items():
if k != 'frames':
norm_value = float(v) / num_frames
ans.append((k, norm_value))
return ans
def reduce(self, device):
"""
Reduce using torch.distributed, which I believe ensures that
all processes get the total.
"""
keys = sorted(self.keys())
s = torch.tensor([ float(self[k]) for k in keys ],
device=device)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
for k, v in zip(keys, s.cpu().tolist()):
self[k] = v
def write_summary(self, tb_writer: SummaryWriter, prefix: str, batch_idx: int) -> None:
"""
Add logging information to a TensorBoard writer.
tb_writer: a TensorBoard writer
prefix: a prefix for the name of the loss, e.g. "train/valid_",
or "train/current_"
batch_idx: The current batch index, used as the x-axis of the plot.
"""
for k, v in self.norm_items():
tb_writer.add_scalar(prefix + k, v, batch_idx)
def compute_loss( def compute_loss(
params: AttributeDict, params: AttributeDict,
@ -267,7 +336,7 @@ def compute_loss(
batch: dict, batch: dict,
graph_compiler: CtcTrainingGraphCompiler, graph_compiler: CtcTrainingGraphCompiler,
is_training: bool, is_training: bool,
): ) -> Tuple[Tensor, LossRecord]:
""" """
Compute CTC loss given the model and its inputs. Compute CTC loss given the model and its inputs.
@ -324,13 +393,12 @@ def compute_loss(
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
# train_frames and valid_frames are used for printing. info = LossRecord()
if is_training: # TODO: there are many GPU->CPU transfers here, maybe combine them into one.
params.train_frames = supervision_segments[:, 2].sum().item() info['frames'] = supervision_segments[:, 2].sum().item()
else: info['loss'] = loss.detach().cpu().item()
params.valid_frames = supervision_segments[:, 2].sum().item()
return loss return loss, info
def compute_validation_loss( def compute_validation_loss(
@ -339,16 +407,16 @@ def compute_validation_loss(
graph_compiler: CtcTrainingGraphCompiler, graph_compiler: CtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader,
world_size: int = 1, world_size: int = 1,
) -> None: ) -> LossRecord:
"""Run the validation process. The validation loss """Run the validation process. The validation loss
is saved in `params.valid_loss`. is saved in `params.valid_loss`.
""" """
model.eval() model.eval()
tot_loss = 0.0 tot_loss = LossRecord()
tot_frames = 0.0
for batch_idx, batch in enumerate(valid_dl): for batch_idx, batch in enumerate(valid_dl):
loss = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
batch=batch, batch=batch,
@ -357,22 +425,18 @@ def compute_validation_loss(
) )
assert loss.requires_grad is False assert loss.requires_grad is False
loss_cpu = loss.detach().cpu().item() tot_loss = tot_loss + loss_info
tot_loss += loss_cpu
tot_frames += params.valid_frames
if world_size > 1: if world_size > 1:
s = torch.tensor([tot_loss, tot_frames], device=loss.device) tot_loss.reduce(loss.device)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
s = s.cpu().tolist()
tot_loss = s[0]
tot_frames = s[1]
params.valid_loss = tot_loss / tot_frames loss_value = tot_loss['loss'] / tot_loss['frames']
if params.valid_loss < params.best_valid_loss: if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = params.valid_loss params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch( def train_one_epoch(
@ -411,23 +475,21 @@ def train_one_epoch(
""" """
model.train() model.train()
tot_loss = 0.0 # reset after params.reset_interval of batches tot_loss = LossRecord()
tot_frames = 0.0 # reset after params.reset_interval of batches
params.tot_loss = 0.0
params.tot_frames = 0.0
for batch_idx, batch in enumerate(train_dl): for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
loss = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
batch=batch, batch=batch,
graph_compiler=graph_compiler, graph_compiler=graph_compiler,
is_training=True, is_training=True,
) )
# summary stats.
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances # NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far. # in the batch and there is no normalization to it so far.
@ -437,41 +499,19 @@ def train_one_epoch(
clip_grad_norm_(model.parameters(), 5.0, 2.0) clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step() optimizer.step()
loss_cpu = loss.detach().cpu().item()
tot_frames += params.train_frames
tot_loss += loss_cpu
tot_avg_loss = tot_loss / tot_frames
params.tot_frames += params.train_frames
params.tot_loss += loss_cpu
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
logging.info( logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, " f"Epoch {params.cur_epoch}, batch {batch_idx}, loss[{loss_info}]"
f"batch avg loss {loss_cpu/params.train_frames:.4f}, " f"tot_loss[{tot_loss}], batch size: {batch_size}"
f"total avg loss: {tot_avg_loss:.4f}, "
f"batch size: {batch_size}"
) )
if batch_idx % 10 == 0:
if tb_writer is not None: if tb_writer is not None:
tb_writer.add_scalar( loss_info.write_summary(tb_writer, "train/current_", params.batch_idx_train)
"train/current_loss", tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_loss",
tot_avg_loss,
params.batch_idx_train,
)
if batch_idx > 0 and batch_idx % params.reset_interval == 0:
tot_loss = 0
tot_frames = 0
if batch_idx > 0 and batch_idx % params.valid_interval == 0: if batch_idx > 0 and batch_idx % params.valid_interval == 0:
compute_validation_loss( valid_info = compute_validation_loss(
params=params, params=params,
model=model, model=model,
graph_compiler=graph_compiler, graph_compiler=graph_compiler,
@ -480,12 +520,17 @@ def train_one_epoch(
) )
model.train() model.train()
logging.info( logging.info(
f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f}," f"Epoch {params.cur_epoch}, validation {valid_info}"
f" best valid loss: {params.best_valid_loss:.4f} " )
f"best valid epoch: {params.best_valid_epoch}" if tb_writer is not None:
valid_info.write_summary(
tb_writer,
"train/valid_",
params.batch_idx_train,
) )
params.train_loss = params.tot_loss / params.tot_frames loss_value = tot_loss['loss'] / tot_loss['frames']
params.train_loss = loss_value
if params.train_loss < params.best_train_loss: if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch params.best_train_epoch = params.cur_epoch

View File

@ -2,9 +2,10 @@
import argparse import argparse
import logging import logging
import collections
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from typing import Optional from typing import Optional, Tuple, List
import k2 import k2
import torch import torch
@ -12,6 +13,8 @@ import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from torch import Tensor
from asr_datamodule import YesNoAsrDataModule from asr_datamodule import YesNoAsrDataModule
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from model import Tdnn from model import Tdnn
@ -122,6 +125,8 @@ def get_params() -> AttributeDict:
- valid_interval: Run validation if batch_idx % valid_interval` is 0 - valid_interval: Run validation if batch_idx % valid_interval` is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- beam_size: It is used in k2.ctc_loss - beam_size: It is used in k2.ctc_loss
- reduction: It is used in k2.ctc_loss - reduction: It is used in k2.ctc_loss
@ -142,6 +147,7 @@ def get_params() -> AttributeDict:
"best_valid_epoch": -1, "best_valid_epoch": -1,
"batch_idx_train": 0, "batch_idx_train": 0,
"log_interval": 10, "log_interval": 10,
"reset_interval": 20,
"valid_interval": 10, "valid_interval": 10,
"beam_size": 10, "beam_size": 10,
"reduction": "sum", "reduction": "sum",
@ -239,13 +245,79 @@ def save_checkpoint(
copyfile(src=filename, dst=best_valid_filename) copyfile(src=filename, dst=best_valid_filename)
class LossRecord(collections.defaultdict):
def __init__(self):
# Passing the type 'int' to the base-class constructor
# makes undefined items default to int() which is zero.
super(LossRecord, self).__init__(int)
def __add__(self, other: 'LossRecord') -> 'LossRecord':
ans = LossRecord()
for k, v in self.items():
ans[k] = v
for k, v in other.items():
ans[k] = ans[k] + v
return ans
def __mul__(self, alpha: float) -> 'LossRecord':
ans = LossRecord()
for k, v in self.items():
ans[k] = v * alpha
return ans
def __str__(self) -> str:
ans = ''
for k, v in self.norm_items():
norm_value = '%.4g' % v
ans += (str(k) + '=' + str(norm_value) + ', ')
frames = str(self['frames'])
ans += 'over ' + frames + ' frames.'
return ans
def norm_items(self) -> List[Tuple[str, float]]:
"""
Returns a list of pairs, like:
[('ctc_loss', 0.1), ('att_loss', 0.07)]
"""
num_frames = self['frames'] if 'frames' in self else 1
ans = []
for k, v in self.items():
if k != 'frames':
norm_value = float(v) / num_frames
ans.append((k, norm_value))
return ans
def reduce(self, device):
"""
Reduce using torch.distributed, which I believe ensures that
all processes get the total.
"""
keys = sorted(self.keys())
s = torch.tensor([ float(self[k]) for k in keys ],
device=device)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
for k, v in zip(keys, s.cpu().tolist()):
self[k] = v
def write_summary(self, tb_writer: SummaryWriter, prefix: str, batch_idx: int) -> None:
"""
Add logging information to a TensorBoard writer.
tb_writer: a TensorBoard writer
prefix: a prefix for the name of the loss, e.g. "train/valid_",
or "train/current_"
batch_idx: The current batch index, used as the x-axis of the plot.
"""
for k, v in self.norm_items():
tb_writer.add_scalar(prefix + k, v, batch_idx)
def compute_loss( def compute_loss(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
batch: dict, batch: dict,
graph_compiler: CtcTrainingGraphCompiler, graph_compiler: CtcTrainingGraphCompiler,
is_training: bool, is_training: bool,
): ) -> Tuple[Tensor, LossRecord]:
""" """
Compute CTC loss given the model and its inputs. Compute CTC loss given the model and its inputs.
@ -305,13 +377,12 @@ def compute_loss(
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
# train_frames and valid_frames are used for printing. info = LossRecord()
if is_training: # TODO: there are many GPU->CPU transfers here, maybe combine them into one.
params.train_frames = supervision_segments[:, 2].sum().item() info['frames'] = supervision_segments[:, 2].sum().item()
else: info['loss'] = loss.detach().cpu().item()
params.valid_frames = supervision_segments[:, 2].sum().item()
return loss return loss, info
def compute_validation_loss( def compute_validation_loss(
@ -320,16 +391,16 @@ def compute_validation_loss(
graph_compiler: CtcTrainingGraphCompiler, graph_compiler: CtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader,
world_size: int = 1, world_size: int = 1,
) -> None: ) -> LossRecord:
"""Run the validation process. The validation loss """Run the validation process. The validation loss
is saved in `params.valid_loss`. is saved in `params.valid_loss`.
""" """
model.eval() model.eval()
tot_loss = 0.0 tot_loss = LossRecord()
tot_frames = 0.0
for batch_idx, batch in enumerate(valid_dl): for batch_idx, batch in enumerate(valid_dl):
loss = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
batch=batch, batch=batch,
@ -338,22 +409,18 @@ def compute_validation_loss(
) )
assert loss.requires_grad is False assert loss.requires_grad is False
loss_cpu = loss.detach().cpu().item() tot_loss = tot_loss + loss_info
tot_loss += loss_cpu
tot_frames += params.valid_frames
if world_size > 1: if world_size > 1:
s = torch.tensor([tot_loss, tot_frames], device=loss.device) tot_loss.reduce(loss.device)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
s = s.cpu().tolist()
tot_loss = s[0]
tot_frames = s[1]
params.valid_loss = tot_loss / tot_frames loss_value = tot_loss['loss'] / tot_loss['frames']
if params.valid_loss < params.best_valid_loss: if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = params.valid_loss params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch( def train_one_epoch(
@ -392,19 +459,21 @@ def train_one_epoch(
""" """
model.train() model.train()
tot_loss = 0.0 # sum of losses over all batches tot_loss = LossRecord()
tot_frames = 0.0 # sum of frames over all batches
for batch_idx, batch in enumerate(train_dl): for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
loss = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
batch=batch, batch=batch,
graph_compiler=graph_compiler, graph_compiler=graph_compiler,
is_training=True, is_training=True,
) )
# summary stats.
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances # NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far. # in the batch and there is no normalization to it so far.
@ -414,35 +483,19 @@ def train_one_epoch(
clip_grad_norm_(model.parameters(), 5.0, 2.0) clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step() optimizer.step()
loss_cpu = loss.detach().cpu().item()
tot_frames += params.train_frames
tot_loss += loss_cpu
tot_avg_loss = tot_loss / tot_frames
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
logging.info( logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, " f"Epoch {params.cur_epoch}, batch {batch_idx}, loss[{loss_info}]"
f"batch avg loss {loss_cpu/params.train_frames:.4f}, " f"tot_loss[{tot_loss}], batch size: {batch_size}"
f"total avg loss: {tot_avg_loss:.4f}, "
f"batch size: {batch_size}"
) )
if batch_idx % 10 == 0:
if tb_writer is not None: if tb_writer is not None:
tb_writer.add_scalar( loss_info.write_summary(tb_writer, "train/current_", params.batch_idx_train)
"train/current_loss", tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_loss",
tot_avg_loss,
params.batch_idx_train,
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0: if batch_idx > 0 and batch_idx % params.valid_interval == 0:
compute_validation_loss( valid_info = compute_validation_loss(
params=params, params=params,
model=model, model=model,
graph_compiler=graph_compiler, graph_compiler=graph_compiler,
@ -451,18 +504,17 @@ def train_one_epoch(
) )
model.train() model.train()
logging.info( logging.info(
f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f}," f"Epoch {params.cur_epoch}, validation {valid_info}"
f" best valid loss: {params.best_valid_loss:.4f} "
f"best valid epoch: {params.best_valid_epoch}"
) )
if tb_writer is not None: if tb_writer is not None:
tb_writer.add_scalar( valid_info.write_summary(
"train/valid_loss", tb_writer,
params.valid_loss, "train/valid_",
params.batch_idx_train, params.batch_idx_train,
) )
params.train_loss = tot_loss / tot_frames loss_value = tot_loss['loss'] / tot_loss['frames']
params.train_loss = loss_value
if params.train_loss < params.best_train_loss: if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch params.best_train_epoch = params.cur_epoch