Use LossRecord to record and print loss for the training process

Mingshuang Luo 2021-09-29 10:08:38 +08:00
parent 73f21a379b
commit e74e75acc6
3 changed files with 337 additions and 263 deletions

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang)
# Wei Kang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -18,16 +19,21 @@
import argparse
import collections
import copy
import logging
from pathlib import Path
from shutil import copyfile
from typing import Optional
from typing import Optional, Tuple, List
import k2
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch import Tensor
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse.utils import fix_random_seed
@ -281,13 +287,80 @@ def save_checkpoint(
copyfile(src=filename, dst=best_valid_filename)
class LossRecord(collections.defaultdict):
def __init__(self):
# Passing the type 'int' to the base-class constructor
# makes undefined items default to int() which is zero.
super(LossRecord, self).__init__(int)
def __add__(self, other: 'LossRecord') -> 'LossRecord':
ans = LossRecord()
for k, v in self.items():
ans[k] = v
for k, v in other.items():
ans[k] = ans[k] + v
return ans
def __mul__(self, alpha: float) -> 'LossRecord':
ans = LossRecord()
for k, v in self.items():
ans[k] = v * alpha
return ans
def __str__(self) -> str:
ans = ''
for k, v in self.norm_items():
norm_value = '%.4g' % v
ans += (str(k) + '=' + str(norm_value) + ', ')
frames = str(self['frames'])
ans += 'over ' + frames + ' frames.'
return ans
def norm_items(self) -> List[Tuple[str, float]]:
"""
Returns a list of pairs, like:
[('ctc_loss', 0.1), ('att_loss', 0.07)]
"""
num_frames = self['frames'] if 'frames' in self else 1
ans = []
for k, v in self.items():
if k != 'frames':
norm_value = float(v) / num_frames
ans.append((k, norm_value))
return ans
def reduce(self, device):
"""
Reduce using torch.distributed, which ensures that
all processes get the total.
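Note that the 'frames' count is summed as well, so norm_items()
still reports per-frame averages after the reduction.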
"""
keys = sorted(self.keys())
s = torch.tensor([ float(self[k]) for k in keys ],
device=device)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
for k, v in zip(keys, s.cpu().tolist()):
self[k] = v
def write_summary(self, tb_writer: SummaryWriter, prefix: str, batch_idx: int) -> None:
"""
Add logging information to a TensorBoard writer.
tb_writer: a TensorBoard writer
prefix: a prefix for the name of the loss, e.g. "train/valid_",
or "train/current_"
batch_idx: The current batch index, used as the x-axis of the plot.
"""
for k, v in self.norm_items():
tb_writer.add_scalar(prefix + k, v, batch_idx)
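# Illustrative usage of LossRecord (a sketch, not part of this commit;
# the metric names and numbers below are made up for demonstration).
def _loss_record_demo() -> None:
    info_a = LossRecord()
    info_a['frames'] = 100
    info_a['ctc_loss'] = 25.0
    info_b = LossRecord()
    info_b['frames'] = 300
    info_b['ctc_loss'] = 60.0
    tot = info_a + info_b
    # norm_items() divides every entry except 'frames' by the frame count,
    # so this prints: ctc_loss=0.2125, over 400 frames.
    print(tot)
    # Under DDP, tot.reduce(device) would sum every entry across processes.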
def compute_loss(
params: AttributeDict,
model: nn.Module,
batch: dict,
graph_compiler: BpeCtcTrainingGraphCompiler,
is_training: bool,
):
) -> Tuple[Tensor, LossRecord]:
"""
Compute CTC loss given the model and its inputs.
@ -367,15 +440,18 @@ def compute_loss(
loss = ctc_loss
att_loss = torch.tensor([0])
# train_frames and valid_frames are used for printing.
if is_training:
params.train_frames = supervision_segments[:, 2].sum().item()
else:
params.valid_frames = supervision_segments[:, 2].sum().item()
assert loss.requires_grad == is_training
return loss, ctc_loss.detach(), att_loss.detach()
info = LossRecord()
# TODO: there are many GPU->CPU transfers here, maybe combine them into one.
info['frames'] = supervision_segments[:, 2].sum().item()
info['ctc_loss'] = ctc_loss.detach().cpu().item()
if params.att_rate != 0.0:
info['att_loss'] = att_loss.detach().cpu().item()
info['loss'] = loss.detach().cpu().item()
return loss, info
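# One possible way to address the TODO above (a sketch, not part of this
# commit): gather the per-batch scalars into a single tensor so that only
# one GPU->CPU transfer is needed.  `make_loss_record` is a hypothetical
# helper name.
def make_loss_record(num_frames: int, losses: dict) -> LossRecord:
    # `losses` maps metric names to 0-dim tensors, possibly on the GPU.
    keys = list(losses.keys())
    values = torch.stack([losses[k].detach() for k in keys]).cpu().tolist()
    info = LossRecord()
    info['frames'] = num_frames
    for k, v in zip(keys, values):
        info[k] = v
    return info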
def compute_validation_loss(
@ -384,18 +460,14 @@ def compute_validation_loss(
graph_compiler: BpeCtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> None:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
) -> LossRecord:
"""Run the validation process."""
model.eval()
tot_loss = 0.0
tot_ctc_loss = 0.0
tot_att_loss = 0.0
tot_frames = 0.0
tot_loss = LossRecord()
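# Unlike training, validation sums the records over the whole dataset and,
# when world_size > 1, reduces them across workers before normalizing.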
for batch_idx, batch in enumerate(valid_dl):
loss, ctc_loss, att_loss = compute_loss(
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
@ -403,36 +475,17 @@ def compute_validation_loss(
is_training=False,
)
assert loss.requires_grad is False
assert ctc_loss.requires_grad is False
assert att_loss.requires_grad is False
loss_cpu = loss.detach().cpu().item()
tot_loss += loss_cpu
tot_ctc_loss += ctc_loss.detach().cpu().item()
tot_att_loss += att_loss.detach().cpu().item()
tot_frames += params.valid_frames
tot_loss = tot_loss + loss_info
if world_size > 1:
s = torch.tensor(
[tot_loss, tot_ctc_loss, tot_att_loss, tot_frames],
device=loss.device,
)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
s = s.cpu().tolist()
tot_loss = s[0]
tot_ctc_loss = s[1]
tot_att_loss = s[2]
tot_frames = s[3]
tot_loss.reduce(loss.device)
params.valid_loss = tot_loss / tot_frames
params.valid_ctc_loss = tot_ctc_loss / tot_frames
params.valid_att_loss = tot_att_loss / tot_frames
if params.valid_loss < params.best_valid_loss:
loss_value = tot_loss['loss'] / tot_loss['frames']
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = params.valid_loss
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
@ -471,24 +524,21 @@ def train_one_epoch(
"""
model.train()
tot_loss = 0.0 # sum of losses over all batches
tot_ctc_loss = 0.0
tot_att_loss = 0.0
tot_loss = LossRecord()
tot_frames = 0.0 # sum of frames over all batches
params.tot_loss = 0.0
params.tot_frames = 0.0
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss, ctc_loss, att_loss = compute_loss(
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=True,
)
# summary stats
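# tot_loss is scaled by (1 - 1/reset_interval) before each new batch is
# added, making it an exponentially decaying sum over roughly the last
# reset_interval batches.  Because 'frames' decays by the same factor, the
# per-frame values from norm_items() stay on the same scale.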
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction=='sum' and the loss is summed over utterances
# in the batch; there is no normalization to it so far.
@ -498,75 +548,21 @@ def train_one_epoch(
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
loss_cpu = loss.detach().cpu().item()
ctc_loss_cpu = ctc_loss.detach().cpu().item()
att_loss_cpu = att_loss.detach().cpu().item()
tot_frames += params.train_frames
tot_loss += loss_cpu
tot_ctc_loss += ctc_loss_cpu
tot_att_loss += att_loss_cpu
params.tot_frames += params.train_frames
params.tot_loss += loss_cpu
tot_avg_loss = tot_loss / tot_frames
tot_avg_ctc_loss = tot_ctc_loss / tot_frames
tot_avg_att_loss = tot_att_loss / tot_frames
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, "
f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, "
f"total avg att loss: {tot_avg_att_loss:.4f}, "
f"total avg loss: {tot_avg_loss:.4f}, "
f"batch size: {batch_size}"
f"Epoch {params.cur_epoch}, batch {batch_idx}, loss[{loss_info}]"
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if batch_idx % 10 == 0:
if tb_writer is not None:
tb_writer.add_scalar(
"train/current_ctc_loss",
ctc_loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/current_att_loss",
att_loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/current_loss",
loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_ctc_loss",
tot_avg_ctc_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_att_loss",
tot_avg_att_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_loss",
tot_avg_loss,
params.batch_idx_train,
)
if batch_idx > 0 and batch_idx % params.reset_interval == 0:
tot_loss = 0.0 # sum of losses over all batches
tot_ctc_loss = 0.0
tot_att_loss = 0.0
tot_frames = 0.0 # sum of frames over all batches
loss_info.write_summary(tb_writer, "train/current_", params.batch_idx_train)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
compute_validation_loss(
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
@ -575,32 +571,13 @@ def train_one_epoch(
)
model.train()
logging.info(
f"Epoch {params.cur_epoch}, "
f"valid ctc loss {params.valid_ctc_loss:.4f},"
f"valid att loss {params.valid_att_loss:.4f},"
f"valid loss {params.valid_loss:.4f},"
f" best valid loss: {params.best_valid_loss:.4f} "
f"best valid epoch: {params.best_valid_epoch}"
)
f"Epoch {params.cur_epoch}, validation: {valid_info}"
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/valid_ctc_loss",
params.valid_ctc_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/valid_att_loss",
params.valid_att_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/valid_loss",
params.valid_loss,
params.batch_idx_train,
)
params.train_loss = params.tot_loss / params.tot_frames
valid_info.write_summary(tb_writer, "train/valid_", params.batch_idx_train)
loss_value = tot_loss['loss'] / tot_loss['frames']
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
@ -739,4 +716,4 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()
main()

View File

@ -1,5 +1,6 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -18,9 +19,10 @@
import argparse
import logging
import collections
from pathlib import Path
from shutil import copyfile
from typing import Optional
from typing import Optional, Tuple, List
import k2
import torch
@ -28,6 +30,8 @@ import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch import Tensor
from asr_datamodule import LibriSpeechAsrDataModule
from lhotse.utils import fix_random_seed
from model import TdnnLstm
@ -260,6 +264,71 @@ def save_checkpoint(
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
class LossRecord(collections.defaultdict):
def __init__(self):
# Passing the type 'int' to the base-class constructor
# makes undefined items default to int() which is zero.
super(LossRecord, self).__init__(int)
def __add__(self, other: 'LossRecord') -> 'LossRecord':
ans = LossRecord()
for k, v in self.items():
ans[k] = v
for k, v in other.items():
ans[k] = ans[k] + v
return ans
def __mul__(self, alpha: float) -> 'LossRecord':
ans = LossRecord()
for k, v in self.items():
ans[k] = v * alpha
return ans
def __str__(self) -> str:
ans = ''
for k, v in self.norm_items():
norm_value = '%.4g' % v
ans += (str(k) + '=' + str(norm_value) + ', ')
frames = str(self['frames'])
ans += 'over ' + frames + ' frames.'
return ans
def norm_items(self) -> List[Tuple[str, float]]:
"""
Returns a list of pairs, like:
[('ctc_loss', 0.1), ('att_loss', 0.07)]
"""
num_frames = self['frames'] if 'frames' in self else 1
ans = []
for k, v in self.items():
if k != 'frames':
norm_value = float(v) / num_frames
ans.append((k, norm_value))
return ans
def reduce(self, device):
"""
Reduce using torch.distributed, which ensures that
all processes get the total.
"""
keys = sorted(self.keys())
s = torch.tensor([ float(self[k]) for k in keys ],
device=device)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
for k, v in zip(keys, s.cpu().tolist()):
self[k] = v
def write_summary(self, tb_writer: SummaryWriter, prefix: str, batch_idx: int) -> None:
"""
Add logging information to a TensorBoard writer.
tb_writer: a TensorBoard writer
prefix: a prefix for the name of the loss, e.g. "train/valid_",
or "train/current_"
batch_idx: The current batch index, used as the x-axis of the plot.
"""
for k, v in self.norm_items():
tb_writer.add_scalar(prefix + k, v, batch_idx)
def compute_loss(
params: AttributeDict,
@ -267,7 +336,7 @@ def compute_loss(
batch: dict,
graph_compiler: CtcTrainingGraphCompiler,
is_training: bool,
):
) -> Tuple[Tensor, LossRecord]:
"""
Compute CTC loss given the model and its inputs.
@ -324,13 +393,12 @@ def compute_loss(
assert loss.requires_grad == is_training
# train_frames and valid_frames are used for printing.
if is_training:
params.train_frames = supervision_segments[:, 2].sum().item()
else:
params.valid_frames = supervision_segments[:, 2].sum().item()
info = LossRecord()
# TODO: there are many GPU->CPU transfers here, maybe combine them into one.
info['frames'] = supervision_segments[:, 2].sum().item()
info['loss'] = loss.detach().cpu().item()
return loss
return loss, info
def compute_validation_loss(
@ -339,16 +407,16 @@ def compute_validation_loss(
graph_compiler: CtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> None:
) -> LossRecord:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
model.eval()
tot_loss = 0.0
tot_frames = 0.0
tot_loss = LossRecord()
for batch_idx, batch in enumerate(valid_dl):
loss = compute_loss(
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
@ -357,22 +425,18 @@ def compute_validation_loss(
)
assert loss.requires_grad is False
loss_cpu = loss.detach().cpu().item()
tot_loss += loss_cpu
tot_frames += params.valid_frames
tot_loss = tot_loss + loss_info
if world_size > 1:
s = torch.tensor([tot_loss, tot_frames], device=loss.device)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
s = s.cpu().tolist()
tot_loss = s[0]
tot_frames = s[1]
tot_loss.reduce(loss.device)
params.valid_loss = tot_loss / tot_frames
loss_value = tot_loss['loss'] / tot_loss['frames']
if params.valid_loss < params.best_valid_loss:
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = params.valid_loss
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
@ -411,23 +475,21 @@ def train_one_epoch(
"""
model.train()
tot_loss = 0.0 # reset after params.reset_interval of batches
tot_frames = 0.0 # reset after params.reset_interval of batches
params.tot_loss = 0.0
params.tot_frames = 0.0
tot_loss = LossRecord()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss = compute_loss(
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=True,
)
# summary stats.
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction=='sum' and the loss is summed over utterances
# in the batch; there is no normalization to it so far.
@ -437,41 +499,19 @@ def train_one_epoch(
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
loss_cpu = loss.detach().cpu().item()
tot_frames += params.train_frames
tot_loss += loss_cpu
tot_avg_loss = tot_loss / tot_frames
params.tot_frames += params.train_frames
params.tot_loss += loss_cpu
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
f"total avg loss: {tot_avg_loss:.4f}, "
f"batch size: {batch_size}"
f"Epoch {params.cur_epoch}, batch {batch_idx}, loss[{loss_info}]"
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if batch_idx % 10 == 0:
if tb_writer is not None:
tb_writer.add_scalar(
"train/current_loss",
loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_loss",
tot_avg_loss,
params.batch_idx_train,
)
if batch_idx > 0 and batch_idx % params.reset_interval == 0:
tot_loss = 0
tot_frames = 0
loss_info.write_summary(tb_writer, "train/current_", params.batch_idx_train)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
compute_validation_loss(
valid_info = compute_validation_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
@ -480,12 +520,17 @@ def train_one_epoch(
)
model.train()
logging.info(
f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f},"
f" best valid loss: {params.best_valid_loss:.4f} "
f"best valid epoch: {params.best_valid_epoch}"
)
f"Epoch {params.cur_epoch}, validation {valid_info}"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer,
"train/valid_",
params.batch_idx_train,
)
params.train_loss = params.tot_loss / params.tot_frames
loss_value = tot_loss['loss'] / tot_loss['frames']
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
@ -613,4 +658,4 @@ def main():
if __name__ == "__main__":
main()
main()

View File

@ -2,9 +2,10 @@
import argparse
import logging
import collections
from pathlib import Path
from shutil import copyfile
from typing import Optional
from typing import Optional, Tuple, List
import k2
import torch
@ -12,6 +13,8 @@ import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch import Tensor
from asr_datamodule import YesNoAsrDataModule
from lhotse.utils import fix_random_seed
from model import Tdnn
@ -122,6 +125,8 @@ def get_params() -> AttributeDict:
- valid_interval: Run validation if batch_idx % valid_interval is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- beam_size: It is used in k2.ctc_loss
- reduction: It is used in k2.ctc_loss
@ -142,6 +147,7 @@ def get_params() -> AttributeDict:
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 10,
"reset_interval": 20,
"valid_interval": 10,
"beam_size": 10,
"reduction": "sum",
@ -239,13 +245,79 @@ def save_checkpoint(
copyfile(src=filename, dst=best_valid_filename)
class LossRecord(collections.defaultdict):
def __init__(self):
# Passing the type 'int' to the base-class constructor
# makes undefined items default to int() which is zero.
super(LossRecord, self).__init__(int)
def __add__(self, other: 'LossRecord') -> 'LossRecord':
ans = LossRecord()
for k, v in self.items():
ans[k] = v
for k, v in other.items():
ans[k] = ans[k] + v
return ans
def __mul__(self, alpha: float) -> 'LossRecord':
ans = LossRecord()
for k, v in self.items():
ans[k] = v * alpha
return ans
def __str__(self) -> str:
ans = ''
for k, v in self.norm_items():
norm_value = '%.4g' % v
ans += (str(k) + '=' + str(norm_value) + ', ')
frames = str(self['frames'])
ans += 'over ' + frames + ' frames.'
return ans
def norm_items(self) -> List[Tuple[str, float]]:
"""
Returns a list of pairs, like:
[('ctc_loss', 0.1), ('att_loss', 0.07)]
"""
num_frames = self['frames'] if 'frames' in self else 1
ans = []
for k, v in self.items():
if k != 'frames':
norm_value = float(v) / num_frames
ans.append((k, norm_value))
return ans
def reduce(self, device):
"""
Reduce using torch.distributed, which ensures that
all processes get the total.
"""
keys = sorted(self.keys())
s = torch.tensor([ float(self[k]) for k in keys ],
device=device)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
for k, v in zip(keys, s.cpu().tolist()):
self[k] = v
def write_summary(self, tb_writer: SummaryWriter, prefix: str, batch_idx: int) -> None:
"""
Add logging information to a TensorBoard writer.
tb_writer: a TensorBoard writer
prefix: a prefix for the name of the loss, e.g. "train/valid_",
or "train/current_"
batch_idx: The current batch index, used as the x-axis of the plot.
"""
for k, v in self.norm_items():
tb_writer.add_scalar(prefix + k, v, batch_idx)
def compute_loss(
params: AttributeDict,
model: nn.Module,
batch: dict,
graph_compiler: CtcTrainingGraphCompiler,
is_training: bool,
):
) -> Tuple[Tensor, LossRecord]:
"""
Compute CTC loss given the model and its inputs.
@ -304,14 +376,13 @@ def compute_loss(
)
assert loss.requires_grad == is_training
info = LossRecord()
# TODO: there are many GPU->CPU transfers here, maybe combine them into one.
info['frames'] = supervision_segments[:, 2].sum().item()
info['loss'] = loss.detach().cpu().item()
# train_frames and valid_frames are used for printing.
if is_training:
params.train_frames = supervision_segments[:, 2].sum().item()
else:
params.valid_frames = supervision_segments[:, 2].sum().item()
return loss
return loss, info
def compute_validation_loss(
@ -320,16 +391,16 @@ def compute_validation_loss(
graph_compiler: CtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> None:
) -> LossRecord:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
model.eval()
tot_loss = 0.0
tot_frames = 0.0
tot_loss = LossRecord()
for batch_idx, batch in enumerate(valid_dl):
loss = compute_loss(
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
@ -338,22 +409,18 @@ def compute_validation_loss(
)
assert loss.requires_grad is False
loss_cpu = loss.detach().cpu().item()
tot_loss += loss_cpu
tot_frames += params.valid_frames
tot_loss = tot_loss + loss_info
if world_size > 1:
s = torch.tensor([tot_loss, tot_frames], device=loss.device)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
s = s.cpu().tolist()
tot_loss = s[0]
tot_frames = s[1]
tot_loss.reduce(loss.device)
params.valid_loss = tot_loss / tot_frames
loss_value = tot_loss['loss'] / tot_loss['frames']
if params.valid_loss < params.best_valid_loss:
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = params.valid_loss
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
@ -391,20 +458,22 @@ def train_one_epoch(
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = LossRecord()
tot_loss = 0.0 # sum of losses over all batches
tot_frames = 0.0 # sum of frames over all batches
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss = compute_loss(
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=True,
)
# summary stats.
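# With reset_interval=20, each earlier batch is down-weighted by a factor
# of 0.95 per step, so a batch seen 20 steps ago contributes with weight
# 0.95**20 ~= 0.36.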
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction=='sum' and the loss is summed over utterances
# in the batch; there is no normalization to it so far.
@ -414,35 +483,19 @@ def train_one_epoch(
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
loss_cpu = loss.detach().cpu().item()
tot_frames += params.train_frames
tot_loss += loss_cpu
tot_avg_loss = tot_loss / tot_frames
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
f"total avg loss: {tot_avg_loss:.4f}, "
f"batch size: {batch_size}"
f"Epoch {params.cur_epoch}, batch {batch_idx}, loss[{loss_info}]"
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if batch_idx % 10 == 0:
if tb_writer is not None:
tb_writer.add_scalar(
"train/current_loss",
loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_loss",
tot_avg_loss,
params.batch_idx_train,
)
loss_info.write_summary(tb_writer, "train/current_", params.batch_idx_train)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
compute_validation_loss(
valid_info = compute_validation_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
@ -451,18 +504,17 @@ def train_one_epoch(
)
model.train()
logging.info(
f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f},"
f" best valid loss: {params.best_valid_loss:.4f} "
f"best valid epoch: {params.best_valid_epoch}"
)
f"Epoch {params.cur_epoch}, validation {valid_info}"
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/valid_loss",
params.valid_loss,
valid_info.write_summary(
tb_writer,
"train/valid_",
params.batch_idx_train,
)
params.train_loss = tot_loss / tot_frames
loss_value = tot_loss['loss'] / tot_loss['frames']
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch