Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-27 10:44:19 +00:00)
Use LossRecord to record and print loss for the training process
This commit is contained in:
parent 73f21a379b
commit e74e75acc6
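The diff below touches three training scripts (judging from the imports in the hunks: a Conformer CTC recipe, a TDNN-LSTM CTC recipe, and the yes/no TDNN recipe). Each gains the same LossRecord helper: a defaultdict keyed by metric name ('frames', 'loss', 'ctc_loss', 'att_loss', ...) that supports addition, scaling by a constant, printing per-frame averages, and all-reduce across workers. Below is a minimal, self-contained sketch of how it behaves; the class body is condensed from the diff (reduce() and write_summary() are omitted) and the example numbers are invented for illustration.

import collections
from typing import List, Tuple


class LossRecord(collections.defaultdict):
    # Condensed from the class added in this commit; reduce() and
    # write_summary() are left out for brevity.
    def __init__(self):
        # Missing keys default to int() == 0, so accumulation just works.
        super(LossRecord, self).__init__(int)

    def __add__(self, other: 'LossRecord') -> 'LossRecord':
        ans = LossRecord()
        for k, v in self.items():
            ans[k] = v
        for k, v in other.items():
            ans[k] = ans[k] + v
        return ans

    def __mul__(self, alpha: float) -> 'LossRecord':
        ans = LossRecord()
        for k, v in self.items():
            ans[k] = v * alpha
        return ans

    def __str__(self) -> str:
        ans = ''
        for k, v in self.norm_items():
            norm_value = '%.4g' % v
            ans += (str(k) + '=' + str(norm_value) + ', ')
        frames = str(self['frames'])
        ans += 'over ' + frames + ' frames.'
        return ans

    def norm_items(self) -> List[Tuple[str, float]]:
        # Per-frame averages for every key except 'frames' itself.
        num_frames = self['frames'] if 'frames' in self else 1
        ans = []
        for k, v in self.items():
            if k != 'frames':
                ans.append((k, float(v) / num_frames))
        return ans


# Per-batch record, as compute_loss() fills it in (values are made up):
batch_info = LossRecord()
batch_info['frames'] = 1500      # supervised frames in this batch
batch_info['ctc_loss'] = 400.0   # summed (reduction='sum') CTC loss
batch_info['loss'] = 450.0       # total loss

# Running total with a forgetting factor, as in train_one_epoch():
reset_interval = 200             # illustrative value
tot_loss = LossRecord()
tot_loss = tot_loss * (1 - 1 / reset_interval) + batch_info

print(batch_info)  # e.g. "ctc_loss=0.2667, loss=0.3, over 1500 frames."
print(tot_loss)    # same averages in this single-batch example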
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
-# Wei Kang)
+# Wei Kang
+# Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -18,16 +19,21 @@

+
 import argparse
+import collections
+import copy
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional
+from typing import Optional, Tuple, List

+
 import k2
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
+from torch import Tensor

 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 from lhotse.utils import fix_random_seed
@@ -281,13 +287,80 @@ def save_checkpoint(
     copyfile(src=filename, dst=best_valid_filename)


+class LossRecord(collections.defaultdict):
+    def __init__(self):
+        # Passing the type 'int' to the base-class constructor
+        # makes undefined items default to int() which is zero.
+        super(LossRecord, self).__init__(int)
+
+    def __add__(self, other: 'LossRecord') -> 'LossRecord':
+        ans = LossRecord()
+        for k, v in self.items():
+            ans[k] = v
+        for k, v in other.items():
+            ans[k] = ans[k] + v
+        return ans
+
+    def __mul__(self, alpha: float) -> 'LossRecord':
+        ans = LossRecord()
+        for k, v in self.items():
+            ans[k] = v * alpha
+        return ans
+
+
+    def __str__(self) -> str:
+        ans = ''
+        for k, v in self.norm_items():
+            norm_value = '%.4g' % v
+            ans += (str(k) + '=' + str(norm_value) + ', ')
+        frames = str(self['frames'])
+        ans += 'over ' + frames + ' frames.'
+        return ans
+
+    def norm_items(self) -> List[Tuple[str, float]]:
+        """
+        Returns a list of pairs, like:
+          [('ctc_loss', 0.1), ('att_loss', 0.07)]
+        """
+        num_frames = self['frames'] if 'frames' in self else 1
+        ans = []
+        for k, v in self.items():
+            if k != 'frames':
+                norm_value = float(v) / num_frames
+                ans.append((k, norm_value))
+        return ans
+
+    def reduce(self, device):
+        """
+        Reduce using torch.distributed, which I believe ensures that
+        all processes get the total.
+        """
+        keys = sorted(self.keys())
+        s = torch.tensor([ float(self[k]) for k in keys ],
+                         device=device)
+        dist.all_reduce(s, op=dist.ReduceOp.SUM)
+        for k, v in zip(keys, s.cpu().tolist()):
+            self[k] = v
+
+    def write_summary(self, tb_writer: SummaryWriter, prefix: str, batch_idx: int) -> None:
+        """
+        Add logging information to a TensorBoard writer.
+        tb_writer: a TensorBoard writer
+        prefix: a prefix for the name of the loss, e.g. "train/valid_",
+            or "train/current_"
+        batch_idx: The current batch index, used as the x-axis of the plot.
+        """
+        for k, v in self.norm_items():
+            tb_writer.add_scalar(prefix + k, v, batch_idx)
+
+
 def compute_loss(
     params: AttributeDict,
     model: nn.Module,
     batch: dict,
     graph_compiler: BpeCtcTrainingGraphCompiler,
     is_training: bool,
-):
+) -> Tuple[Tensor, LossRecord]:
     """
     Compute CTC loss given the model and its inputs.

@@ -367,15 +440,18 @@ def compute_loss(
         loss = ctc_loss
         att_loss = torch.tensor([0])

-    # train_frames and valid_frames are used for printing.
-    if is_training:
-        params.train_frames = supervision_segments[:, 2].sum().item()
-    else:
-        params.valid_frames = supervision_segments[:, 2].sum().item()
-
     assert loss.requires_grad == is_training

-    return loss, ctc_loss.detach(), att_loss.detach()
+    info = LossRecord()
+    # TODO: there are many GPU->CPU transfers here, maybe combine them into one.
+    info['frames'] = supervision_segments[:, 2].sum().item()
+    info['ctc_loss'] = ctc_loss.detach().cpu().item()
+    if params.att_rate != 0.0:
+        info['att_loss'] = att_loss.detach().cpu().item()
+
+    info['loss'] = loss.detach().cpu().item()
+
+    return loss, info


 def compute_validation_loss(
@@ -384,18 +460,14 @@ def compute_validation_loss(
     graph_compiler: BpeCtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
-) -> None:
-    """Run the validation process. The validation loss
-    is saved in `params.valid_loss`.
-    """
+) -> LossRecord:
+    """Run the validation process."""
     model.eval()

-    tot_loss = 0.0
-    tot_ctc_loss = 0.0
-    tot_att_loss = 0.0
-    tot_frames = 0.0
+    tot_loss = LossRecord()
+
     for batch_idx, batch in enumerate(valid_dl):
-        loss, ctc_loss, att_loss = compute_loss(
+        loss, loss_info = compute_loss(
             params=params,
             model=model,
             batch=batch,
@@ -403,36 +475,17 @@ def compute_validation_loss(
             is_training=False,
         )
         assert loss.requires_grad is False
-        assert ctc_loss.requires_grad is False
-        assert att_loss.requires_grad is False
-
-        loss_cpu = loss.detach().cpu().item()
-        tot_loss += loss_cpu
-
-        tot_ctc_loss += ctc_loss.detach().cpu().item()
-        tot_att_loss += att_loss.detach().cpu().item()
-
-        tot_frames += params.valid_frames
+        tot_loss = tot_loss + loss_info

     if world_size > 1:
-        s = torch.tensor(
-            [tot_loss, tot_ctc_loss, tot_att_loss, tot_frames],
-            device=loss.device,
-        )
-        dist.all_reduce(s, op=dist.ReduceOp.SUM)
-        s = s.cpu().tolist()
-        tot_loss = s[0]
-        tot_ctc_loss = s[1]
-        tot_att_loss = s[2]
-        tot_frames = s[3]
+        tot_loss.reduce(loss.device)

-    params.valid_loss = tot_loss / tot_frames
-    params.valid_ctc_loss = tot_ctc_loss / tot_frames
-    params.valid_att_loss = tot_att_loss / tot_frames
-
-    if params.valid_loss < params.best_valid_loss:
+    loss_value = tot_loss['loss'] / tot_loss['frames']
+    if loss_value < params.best_valid_loss:
         params.best_valid_epoch = params.cur_epoch
-        params.best_valid_loss = params.valid_loss
+        params.best_valid_loss = loss_value
+
+    return tot_loss


 def train_one_epoch(
@@ -471,24 +524,21 @@ def train_one_epoch(
     """
     model.train()

-    tot_loss = 0.0  # sum of losses over all batches
-    tot_ctc_loss = 0.0
-    tot_att_loss = 0.0
+    tot_loss = LossRecord()

-    tot_frames = 0.0  # sum of frames over all batches
-    params.tot_loss = 0.0
-    params.tot_frames = 0.0
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        loss, ctc_loss, att_loss = compute_loss(
+        loss, loss_info = compute_loss(
             params=params,
             model=model,
             batch=batch,
             graph_compiler=graph_compiler,
             is_training=True,
         )
+        # summary stats
+        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.
@@ -498,75 +548,21 @@ def train_one_epoch(
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()

-        loss_cpu = loss.detach().cpu().item()
-        ctc_loss_cpu = ctc_loss.detach().cpu().item()
-        att_loss_cpu = att_loss.detach().cpu().item()
-
-        tot_frames += params.train_frames
-        tot_loss += loss_cpu
-        tot_ctc_loss += ctc_loss_cpu
-        tot_att_loss += att_loss_cpu
-
-        params.tot_frames += params.train_frames
-        params.tot_loss += loss_cpu
-
-        tot_avg_loss = tot_loss / tot_frames
-        tot_avg_ctc_loss = tot_ctc_loss / tot_frames
-        tot_avg_att_loss = tot_att_loss / tot_frames
-
         if batch_idx % params.log_interval == 0:
             logging.info(
-                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
-                f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, "
-                f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
-                f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
-                f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, "
-                f"total avg att loss: {tot_avg_att_loss:.4f}, "
-                f"total avg loss: {tot_avg_loss:.4f}, "
-                f"batch size: {batch_size}"
+                f"Epoch {params.cur_epoch}, batch {batch_idx}, loss[{loss_info}]"
+                f"tot_loss[{tot_loss}], batch size: {batch_size}"
             )

+        if batch_idx % 10 == 0:
+
             if tb_writer is not None:
-                tb_writer.add_scalar(
-                    "train/current_ctc_loss",
-                    ctc_loss_cpu / params.train_frames,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/current_att_loss",
-                    att_loss_cpu / params.train_frames,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/current_loss",
-                    loss_cpu / params.train_frames,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/tot_avg_ctc_loss",
-                    tot_avg_ctc_loss,
-                    params.batch_idx_train,
-                )
-
-                tb_writer.add_scalar(
-                    "train/tot_avg_att_loss",
-                    tot_avg_att_loss,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/tot_avg_loss",
-                    tot_avg_loss,
-                    params.batch_idx_train,
-                )
-        if batch_idx > 0 and batch_idx % params.reset_interval == 0:
-            tot_loss = 0.0  # sum of losses over all batches
-            tot_ctc_loss = 0.0
-            tot_att_loss = 0.0
-
-            tot_frames = 0.0  # sum of frames over all batches
+                loss_info.write_summary(tb_writer, "train/current_", params.batch_idx_train)
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)

         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
-            compute_validation_loss(
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
                 params=params,
                 model=model,
                 graph_compiler=graph_compiler,
@@ -575,32 +571,13 @@ def train_one_epoch(
             )
             model.train()
             logging.info(
-                f"Epoch {params.cur_epoch}, "
-                f"valid ctc loss {params.valid_ctc_loss:.4f},"
-                f"valid att loss {params.valid_att_loss:.4f},"
-                f"valid loss {params.valid_loss:.4f},"
-                f" best valid loss: {params.best_valid_loss:.4f} "
-                f"best valid epoch: {params.best_valid_epoch}"
+                f"Epoch {params.cur_epoch}, validation: {valid_info}"
             )
             if tb_writer is not None:
-                tb_writer.add_scalar(
-                    "train/valid_ctc_loss",
-                    params.valid_ctc_loss,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/valid_att_loss",
-                    params.valid_att_loss,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/valid_loss",
-                    params.valid_loss,
-                    params.batch_idx_train,
-                )
+                valid_info.write_summary(tb_writer, "train/valid_", params.batch_idx_train)

-    params.train_loss = params.tot_loss / params.tot_frames
-
+    loss_value = tot_loss['loss'] / tot_loss['frames']
+    params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
         params.best_train_loss = params.train_loss
|
@ -1,5 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
|
||||||
|
# Mingshuang Luo)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@@ -18,9 +19,10 @@

 import argparse
 import logging
+import collections
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional
+from typing import Optional, Tuple, List

 import k2
 import torch
@@ -28,6 +30,8 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
+from torch import Tensor
+
 from asr_datamodule import LibriSpeechAsrDataModule
 from lhotse.utils import fix_random_seed
 from model import TdnnLstm
@@ -260,6 +264,71 @@ def save_checkpoint(
     best_valid_filename = params.exp_dir / "best-valid-loss.pt"
     copyfile(src=filename, dst=best_valid_filename)

+class LossRecord(collections.defaultdict):
+    def __init__(self):
+        # Passing the type 'int' to the base-class constructor
+        # makes undefined items default to int() which is zero.
+        super(LossRecord, self).__init__(int)
+
+    def __add__(self, other: 'LossRecord') -> 'LossRecord':
+        ans = LossRecord()
+        for k, v in self.items():
+            ans[k] = v
+        for k, v in other.items():
+            ans[k] = ans[k] + v
+        return ans
+
+    def __mul__(self, alpha: float) -> 'LossRecord':
+        ans = LossRecord()
+        for k, v in self.items():
+            ans[k] = v * alpha
+        return ans
+
+    def __str__(self) -> str:
+        ans = ''
+        for k, v in self.norm_items():
+            norm_value = '%.4g' % v
+            ans += (str(k) + '=' + str(norm_value) + ', ')
+        frames = str(self['frames'])
+        ans += 'over ' + frames + ' frames.'
+        return ans
+
+    def norm_items(self) -> List[Tuple[str, float]]:
+        """
+        Returns a list of pairs, like:
+          [('ctc_loss', 0.1), ('att_loss', 0.07)]
+        """
+        num_frames = self['frames'] if 'frames' in self else 1
+        ans = []
+        for k, v in self.items():
+            if k != 'frames':
+                norm_value = float(v) / num_frames
+                ans.append((k, norm_value))
+        return ans
+
+    def reduce(self, device):
+        """
+        Reduce using torch.distributed, which I believe ensures that
+        all processes get the total.
+        """
+        keys = sorted(self.keys())
+        s = torch.tensor([ float(self[k]) for k in keys ],
+                         device=device)
+        dist.all_reduce(s, op=dist.ReduceOp.SUM)
+        for k, v in zip(keys, s.cpu().tolist()):
+            self[k] = v
+
+    def write_summary(self, tb_writer: SummaryWriter, prefix: str, batch_idx: int) -> None:
+        """
+        Add logging information to a TensorBoard writer.
+        tb_writer: a TensorBoard writer
+        prefix: a prefix for the name of the loss, e.g. "train/valid_",
+            or "train/current_"
+        batch_idx: The current batch index, used as the x-axis of the plot.
+        """
+        for k, v in self.norm_items():
+            tb_writer.add_scalar(prefix + k, v, batch_idx)
+

 def compute_loss(
     params: AttributeDict,
@@ -267,7 +336,7 @@ def compute_loss(
     batch: dict,
     graph_compiler: CtcTrainingGraphCompiler,
     is_training: bool,
-):
+) -> Tuple[Tensor, LossRecord]:
     """
     Compute CTC loss given the model and its inputs.

@@ -324,13 +393,12 @@ def compute_loss(

     assert loss.requires_grad == is_training

-    # train_frames and valid_frames are used for printing.
-    if is_training:
-        params.train_frames = supervision_segments[:, 2].sum().item()
-    else:
-        params.valid_frames = supervision_segments[:, 2].sum().item()
+    info = LossRecord()
+    # TODO: there are many GPU->CPU transfers here, maybe combine them into one.
+    info['frames'] = supervision_segments[:, 2].sum().item()
+    info['loss'] = loss.detach().cpu().item()

-    return loss
+    return loss, info


 def compute_validation_loss(
@@ -339,16 +407,16 @@ def compute_validation_loss(
     graph_compiler: CtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
-) -> None:
+) -> LossRecord:
     """Run the validation process. The validation loss
     is saved in `params.valid_loss`.
     """
     model.eval()

-    tot_loss = 0.0
-    tot_frames = 0.0
+    tot_loss = LossRecord()
+
     for batch_idx, batch in enumerate(valid_dl):
-        loss = compute_loss(
+        loss, loss_info = compute_loss(
             params=params,
             model=model,
             batch=batch,
@@ -357,22 +425,18 @@ def compute_validation_loss(
         )
         assert loss.requires_grad is False

-        loss_cpu = loss.detach().cpu().item()
-        tot_loss += loss_cpu
-        tot_frames += params.valid_frames
+        tot_loss = tot_loss + loss_info

     if world_size > 1:
-        s = torch.tensor([tot_loss, tot_frames], device=loss.device)
-        dist.all_reduce(s, op=dist.ReduceOp.SUM)
-        s = s.cpu().tolist()
-        tot_loss = s[0]
-        tot_frames = s[1]
+        tot_loss.reduce(loss.device)

-    params.valid_loss = tot_loss / tot_frames
+    loss_value = tot_loss['loss'] / tot_loss['frames']

-    if params.valid_loss < params.best_valid_loss:
+    if loss_value < params.best_valid_loss:
         params.best_valid_epoch = params.cur_epoch
-        params.best_valid_loss = params.valid_loss
+        params.best_valid_loss = loss_value
+
+    return tot_loss


 def train_one_epoch(
@@ -411,23 +475,21 @@ def train_one_epoch(
     """
     model.train()

-    tot_loss = 0.0  # reset after params.reset_interval of batches
-    tot_frames = 0.0  # reset after params.reset_interval of batches
-
-    params.tot_loss = 0.0
-    params.tot_frames = 0.0
+    tot_loss = LossRecord()

     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        loss = compute_loss(
+        loss, loss_info = compute_loss(
             params=params,
             model=model,
             batch=batch,
             graph_compiler=graph_compiler,
             is_training=True,
         )
+        # summary stats.
+        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.
@@ -437,41 +499,19 @@ def train_one_epoch(
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()

-        loss_cpu = loss.detach().cpu().item()
-
-        tot_frames += params.train_frames
-        tot_loss += loss_cpu
-        tot_avg_loss = tot_loss / tot_frames
-
-        params.tot_frames += params.train_frames
-        params.tot_loss += loss_cpu
-
         if batch_idx % params.log_interval == 0:
             logging.info(
-                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
-                f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
-                f"total avg loss: {tot_avg_loss:.4f}, "
-                f"batch size: {batch_size}"
+                f"Epoch {params.cur_epoch}, batch {batch_idx}, loss[{loss_info}]"
+                f"tot_loss[{tot_loss}], batch size: {batch_size}"
             )
+        if batch_idx % 10 == 0:

             if tb_writer is not None:
-                tb_writer.add_scalar(
-                    "train/current_loss",
-                    loss_cpu / params.train_frames,
-                    params.batch_idx_train,
-                )
-
-                tb_writer.add_scalar(
-                    "train/tot_avg_loss",
-                    tot_avg_loss,
-                    params.batch_idx_train,
-                )
-
-        if batch_idx > 0 and batch_idx % params.reset_interval == 0:
-            tot_loss = 0
-            tot_frames = 0
+                loss_info.write_summary(tb_writer, "train/current_", params.batch_idx_train)
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)

         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
-            compute_validation_loss(
+            valid_info = compute_validation_loss(
                 params=params,
                 model=model,
                 graph_compiler=graph_compiler,
@@ -480,12 +520,17 @@ def train_one_epoch(
             )
             model.train()
             logging.info(
-                f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f},"
-                f" best valid loss: {params.best_valid_loss:.4f} "
-                f"best valid epoch: {params.best_valid_epoch}"
+                f"Epoch {params.cur_epoch}, validation {valid_info}"
             )
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer,
+                    "train/valid_",
+                    params.batch_idx_train,
+                )

-    params.train_loss = params.tot_loss / params.tot_frames
+    loss_value = tot_loss['loss'] / tot_loss['frames']
+    params.train_loss = loss_value

     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
@@ -2,9 +2,10 @@

 import argparse
 import logging
+import collections
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional
+from typing import Optional, Tuple, List

 import k2
 import torch
@@ -12,6 +13,8 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
+from torch import Tensor
+
 from asr_datamodule import YesNoAsrDataModule
 from lhotse.utils import fix_random_seed
 from model import Tdnn
@@ -122,6 +125,8 @@ def get_params() -> AttributeDict:

         - valid_interval: Run validation if batch_idx % valid_interval` is 0

+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
         - beam_size: It is used in k2.ctc_loss

         - reduction: It is used in k2.ctc_loss
@@ -142,6 +147,7 @@ def get_params() -> AttributeDict:
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
             "log_interval": 10,
+            "reset_interval": 20,
             "valid_interval": 10,
             "beam_size": 10,
             "reduction": "sum",
@@ -239,13 +245,79 @@ def save_checkpoint(
     copyfile(src=filename, dst=best_valid_filename)


+class LossRecord(collections.defaultdict):
+    def __init__(self):
+        # Passing the type 'int' to the base-class constructor
+        # makes undefined items default to int() which is zero.
+        super(LossRecord, self).__init__(int)
+
+    def __add__(self, other: 'LossRecord') -> 'LossRecord':
+        ans = LossRecord()
+        for k, v in self.items():
+            ans[k] = v
+        for k, v in other.items():
+            ans[k] = ans[k] + v
+        return ans
+
+    def __mul__(self, alpha: float) -> 'LossRecord':
+        ans = LossRecord()
+        for k, v in self.items():
+            ans[k] = v * alpha
+        return ans
+
+    def __str__(self) -> str:
+        ans = ''
+        for k, v in self.norm_items():
+            norm_value = '%.4g' % v
+            ans += (str(k) + '=' + str(norm_value) + ', ')
+        frames = str(self['frames'])
+        ans += 'over ' + frames + ' frames.'
+        return ans
+
+    def norm_items(self) -> List[Tuple[str, float]]:
+        """
+        Returns a list of pairs, like:
+          [('ctc_loss', 0.1), ('att_loss', 0.07)]
+        """
+        num_frames = self['frames'] if 'frames' in self else 1
+        ans = []
+        for k, v in self.items():
+            if k != 'frames':
+                norm_value = float(v) / num_frames
+                ans.append((k, norm_value))
+        return ans
+
+    def reduce(self, device):
+        """
+        Reduce using torch.distributed, which I believe ensures that
+        all processes get the total.
+        """
+        keys = sorted(self.keys())
+        s = torch.tensor([ float(self[k]) for k in keys ],
+                         device=device)
+        dist.all_reduce(s, op=dist.ReduceOp.SUM)
+        for k, v in zip(keys, s.cpu().tolist()):
+            self[k] = v
+
+    def write_summary(self, tb_writer: SummaryWriter, prefix: str, batch_idx: int) -> None:
+        """
+        Add logging information to a TensorBoard writer.
+        tb_writer: a TensorBoard writer
+        prefix: a prefix for the name of the loss, e.g. "train/valid_",
+            or "train/current_"
+        batch_idx: The current batch index, used as the x-axis of the plot.
+        """
+        for k, v in self.norm_items():
+            tb_writer.add_scalar(prefix + k, v, batch_idx)
+
+
 def compute_loss(
     params: AttributeDict,
     model: nn.Module,
     batch: dict,
     graph_compiler: CtcTrainingGraphCompiler,
     is_training: bool,
-):
+) -> Tuple[Tensor, LossRecord]:
     """
     Compute CTC loss given the model and its inputs.

@@ -305,13 +377,12 @@ def compute_loss(

     assert loss.requires_grad == is_training

-    # train_frames and valid_frames are used for printing.
-    if is_training:
-        params.train_frames = supervision_segments[:, 2].sum().item()
-    else:
-        params.valid_frames = supervision_segments[:, 2].sum().item()
+    info = LossRecord()
+    # TODO: there are many GPU->CPU transfers here, maybe combine them into one.
+    info['frames'] = supervision_segments[:, 2].sum().item()
+    info['loss'] = loss.detach().cpu().item()

-    return loss
+    return loss, info


 def compute_validation_loss(
@@ -320,16 +391,16 @@ def compute_validation_loss(
     graph_compiler: CtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
-) -> None:
+) -> LossRecord:
     """Run the validation process. The validation loss
     is saved in `params.valid_loss`.
     """
     model.eval()

-    tot_loss = 0.0
-    tot_frames = 0.0
+    tot_loss = LossRecord()
+
     for batch_idx, batch in enumerate(valid_dl):
-        loss = compute_loss(
+        loss, loss_info = compute_loss(
             params=params,
             model=model,
             batch=batch,
@@ -338,22 +409,18 @@ def compute_validation_loss(
         )
         assert loss.requires_grad is False

-        loss_cpu = loss.detach().cpu().item()
-        tot_loss += loss_cpu
-        tot_frames += params.valid_frames
+        tot_loss = tot_loss + loss_info

     if world_size > 1:
-        s = torch.tensor([tot_loss, tot_frames], device=loss.device)
-        dist.all_reduce(s, op=dist.ReduceOp.SUM)
-        s = s.cpu().tolist()
-        tot_loss = s[0]
-        tot_frames = s[1]
+        tot_loss.reduce(loss.device)

-    params.valid_loss = tot_loss / tot_frames
+    loss_value = tot_loss['loss'] / tot_loss['frames']

-    if params.valid_loss < params.best_valid_loss:
+    if loss_value < params.best_valid_loss:
         params.best_valid_epoch = params.cur_epoch
-        params.best_valid_loss = params.valid_loss
+        params.best_valid_loss = loss_value
+
+    return tot_loss


 def train_one_epoch(
||||||
@ -392,19 +459,21 @@ def train_one_epoch(
|
|||||||
"""
|
"""
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
tot_loss = 0.0 # sum of losses over all batches
|
tot_loss = LossRecord()
|
||||||
tot_frames = 0.0 # sum of frames over all batches
|
|
||||||
for batch_idx, batch in enumerate(train_dl):
|
for batch_idx, batch in enumerate(train_dl):
|
||||||
params.batch_idx_train += 1
|
params.batch_idx_train += 1
|
||||||
batch_size = len(batch["supervisions"]["text"])
|
batch_size = len(batch["supervisions"]["text"])
|
||||||
|
|
||||||
loss = compute_loss(
|
loss, loss_info = compute_loss(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
graph_compiler=graph_compiler,
|
graph_compiler=graph_compiler,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
)
|
)
|
||||||
|
# summary stats.
|
||||||
|
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
||||||
|
|
||||||
# NOTE: We use reduction==sum and loss is computed over utterances
|
# NOTE: We use reduction==sum and loss is computed over utterances
|
||||||
# in the batch and there is no normalization to it so far.
|
# in the batch and there is no normalization to it so far.
|
||||||
@@ -414,35 +483,19 @@ def train_one_epoch(
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()

-        loss_cpu = loss.detach().cpu().item()
-
-        tot_frames += params.train_frames
-        tot_loss += loss_cpu
-        tot_avg_loss = tot_loss / tot_frames
-
         if batch_idx % params.log_interval == 0:
             logging.info(
-                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
-                f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
-                f"total avg loss: {tot_avg_loss:.4f}, "
-                f"batch size: {batch_size}"
+                f"Epoch {params.cur_epoch}, batch {batch_idx}, loss[{loss_info}]"
+                f"tot_loss[{tot_loss}], batch size: {batch_size}"
             )
+        if batch_idx % 10 == 0:

             if tb_writer is not None:
-                tb_writer.add_scalar(
-                    "train/current_loss",
-                    loss_cpu / params.train_frames,
-                    params.batch_idx_train,
-                )
-
-                tb_writer.add_scalar(
-                    "train/tot_avg_loss",
-                    tot_avg_loss,
-                    params.batch_idx_train,
-                )
+                loss_info.write_summary(tb_writer, "train/current_", params.batch_idx_train)
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)

         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
-            compute_validation_loss(
+            valid_info = compute_validation_loss(
                 params=params,
                 model=model,
                 graph_compiler=graph_compiler,
@@ -451,18 +504,17 @@ def train_one_epoch(
             )
             model.train()
             logging.info(
-                f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f},"
-                f" best valid loss: {params.best_valid_loss:.4f} "
-                f"best valid epoch: {params.best_valid_epoch}"
+                f"Epoch {params.cur_epoch}, validation {valid_info}"
             )
             if tb_writer is not None:
-                tb_writer.add_scalar(
-                    "train/valid_loss",
-                    params.valid_loss,
+                valid_info.write_summary(
+                    tb_writer,
+                    "train/valid_",
                     params.batch_idx_train,
                 )

-    params.train_loss = tot_loss / tot_frames
+    loss_value = tot_loss['loss'] / tot_loss['frames']
+    params.train_loss = loss_value

     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
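A note on the running statistics in the hunks above: the old running sums (reset every params.reset_interval batches in the first two scripts, and accumulated over the whole epoch in the yes/no script) are replaced by the update tot_loss = tot_loss * (1 - 1 / params.reset_interval) + loss_info. After batch i, tot_loss is therefore a geometrically weighted sum, roughly sum over j <= i of (1 - 1/r)**(i - j) * loss_info_j with r = params.reset_interval. As an illustration with the reset_interval of 20 added for the yes/no recipe, a batch's contribution is scaled by 0.95 on every subsequent batch and has decayed to about half after roughly 13-14 batches, so the printed "tot_" averages track approximately the last reset_interval batches instead of dropping to zero at each reset.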