Use LossRecord to record and print the loss for the training process (#62)

* Update index.rst (AS->ASR)

* Update conformer_ctc.rst (pretraind->pretrained)

* Fix some spelling errors.

* Fix some spelling errors.

* Use LossRecord to record and print loss in the training process

* Change the name "LossRecord" to "MetricsTracker"
Author: Mingshuang Luo, 2021-10-12 15:58:03 +08:00 (committed by GitHub)
parent beb54ddb61
commit 597c5efdb1
6 changed files with 222 additions and 274 deletions
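
The core of the change: compute_loss() now returns a MetricsTracker (a defaultdict of summed statistics plus a frame count) alongside the loss, and the training loops accumulate and print these trackers instead of hand-maintained floats. A minimal usage sketch, assuming icefall with the MetricsTracker added by this commit is importable; all numbers are made up:

from icefall.utils import MetricsTracker

tot_loss = MetricsTracker()
for step in range(3):
    info = MetricsTracker()
    info["frames"] = 1000      # frames in this batch
    info["ctc_loss"] = 520.0   # summed (un-normalized) loss over the batch
    info["loss"] = 520.0
    # decaying running total; 200 stands in for params.reset_interval
    tot_loss = tot_loss * (1 - 1 / 200) + info

# __str__ prints per-frame values, e.g.
# "loss[ctc_loss=0.52, loss=0.52, over 1000 frames.]"
print(f"loss[{info}], tot_loss[{tot_loss}]")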


@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
-#                                        Wei Kang)
+#                                        Wei Kang
+#                                        Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -21,13 +22,15 @@ import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional
+from typing import Optional, Tuple

 import k2
 import torch
-import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
+from torch import Tensor
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 from lhotse.utils import fix_random_seed
@@ -43,6 +46,7 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    MetricsTracker,
     encode_supervisions,
     setup_logger,
     str2bool,
@@ -287,7 +291,7 @@ def compute_loss(
     batch: dict,
     graph_compiler: BpeCtcTrainingGraphCompiler,
     is_training: bool,
-):
+) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
@@ -367,15 +371,17 @@ def compute_loss(
         loss = ctc_loss
         att_loss = torch.tensor([0])

-    # train_frames and valid_frames are used for printing.
-    if is_training:
-        params.train_frames = supervision_segments[:, 2].sum().item()
-    else:
-        params.valid_frames = supervision_segments[:, 2].sum().item()

     assert loss.requires_grad == is_training

-    return loss, ctc_loss.detach(), att_loss.detach()
+    info = MetricsTracker()
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    info["ctc_loss"] = ctc_loss.detach().cpu().item()
+    if params.att_rate != 0.0:
+        info["att_loss"] = att_loss.detach().cpu().item()
+
+    info["loss"] = loss.detach().cpu().item()
+
+    return loss, info


 def compute_validation_loss(
@@ -384,18 +390,14 @@ def compute_validation_loss(
     graph_compiler: BpeCtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
-) -> None:
-    """Run the validation process. The validation loss
-    is saved in `params.valid_loss`.
-    """
+) -> MetricsTracker:
+    """Run the validation process."""
     model.eval()

-    tot_loss = 0.0
-    tot_ctc_loss = 0.0
-    tot_att_loss = 0.0
-    tot_frames = 0.0
+    tot_loss = MetricsTracker()

     for batch_idx, batch in enumerate(valid_dl):
-        loss, ctc_loss, att_loss = compute_loss(
+        loss, loss_info = compute_loss(
             params=params,
             model=model,
             batch=batch,
@@ -403,36 +405,17 @@ def compute_validation_loss(
             is_training=False,
         )
         assert loss.requires_grad is False
-        assert ctc_loss.requires_grad is False
-        assert att_loss.requires_grad is False
-
-        loss_cpu = loss.detach().cpu().item()
-        tot_loss += loss_cpu
-        tot_ctc_loss += ctc_loss.detach().cpu().item()
-        tot_att_loss += att_loss.detach().cpu().item()
-        tot_frames += params.valid_frames
+        tot_loss = tot_loss + loss_info

     if world_size > 1:
-        s = torch.tensor(
-            [tot_loss, tot_ctc_loss, tot_att_loss, tot_frames],
-            device=loss.device,
-        )
-        dist.all_reduce(s, op=dist.ReduceOp.SUM)
-        s = s.cpu().tolist()
-        tot_loss = s[0]
-        tot_ctc_loss = s[1]
-        tot_att_loss = s[2]
-        tot_frames = s[3]
+        tot_loss.reduce(loss.device)

-    params.valid_loss = tot_loss / tot_frames
-    params.valid_ctc_loss = tot_ctc_loss / tot_frames
-    params.valid_att_loss = tot_att_loss / tot_frames
-
-    if params.valid_loss < params.best_valid_loss:
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
         params.best_valid_epoch = params.cur_epoch
-        params.best_valid_loss = params.valid_loss
+        params.best_valid_loss = loss_value
+
+    return tot_loss


 def train_one_epoch(
@@ -471,24 +454,21 @@ def train_one_epoch(
     """
     model.train()

-    tot_loss = 0.0  # sum of losses over all batches
-    tot_ctc_loss = 0.0
-    tot_att_loss = 0.0
-    tot_frames = 0.0  # sum of frames over all batches
-
-    params.tot_loss = 0.0
-    params.tot_frames = 0.0
+    tot_loss = MetricsTracker()

     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1

         batch_size = len(batch["supervisions"]["text"])

-        loss, ctc_loss, att_loss = compute_loss(
+        loss, loss_info = compute_loss(
             params=params,
             model=model,
             batch=batch,
             graph_compiler=graph_compiler,
             is_training=True,
         )
+        # summary stats
+        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.
@@ -498,75 +478,26 @@ def train_one_epoch(
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()

-        loss_cpu = loss.detach().cpu().item()
-        ctc_loss_cpu = ctc_loss.detach().cpu().item()
-        att_loss_cpu = att_loss.detach().cpu().item()
-
-        tot_frames += params.train_frames
-        tot_loss += loss_cpu
-        tot_ctc_loss += ctc_loss_cpu
-        tot_att_loss += att_loss_cpu
-
-        params.tot_frames += params.train_frames
-        params.tot_loss += loss_cpu
-
-        tot_avg_loss = tot_loss / tot_frames
-        tot_avg_ctc_loss = tot_ctc_loss / tot_frames
-        tot_avg_att_loss = tot_att_loss / tot_frames
-
         if batch_idx % params.log_interval == 0:
             logging.info(
-                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
-                f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, "
-                f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
-                f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
-                f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, "
-                f"total avg att loss: {tot_avg_att_loss:.4f}, "
-                f"total avg loss: {tot_avg_loss:.4f}, "
-                f"batch size: {batch_size}"
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}"
             )

-        if batch_idx % 10 == 0:
             if tb_writer is not None:
-                tb_writer.add_scalar(
-                    "train/current_ctc_loss",
-                    ctc_loss_cpu / params.train_frames,
-                    params.batch_idx_train,
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
                 )
-                tb_writer.add_scalar(
-                    "train/current_att_loss",
-                    att_loss_cpu / params.train_frames,
-                    params.batch_idx_train,
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
                 )
-                tb_writer.add_scalar(
-                    "train/current_loss",
-                    loss_cpu / params.train_frames,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/tot_avg_ctc_loss",
-                    tot_avg_ctc_loss,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/tot_avg_att_loss",
-                    tot_avg_att_loss,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/tot_avg_loss",
-                    tot_avg_loss,
-                    params.batch_idx_train,
-                )
-
-        if batch_idx > 0 and batch_idx % params.reset_interval == 0:
-            tot_loss = 0.0  # sum of losses over all batches
-            tot_ctc_loss = 0.0
-            tot_att_loss = 0.0
-            tot_frames = 0.0  # sum of frames over all batches
-
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
-            compute_validation_loss(
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
                 params=params,
                 model=model,
                 graph_compiler=graph_compiler,
@@ -574,33 +505,14 @@ def train_one_epoch(
                 world_size=world_size,
             )
             model.train()
-            logging.info(
-                f"Epoch {params.cur_epoch}, "
-                f"valid ctc loss {params.valid_ctc_loss:.4f},"
-                f"valid att loss {params.valid_att_loss:.4f},"
-                f"valid loss {params.valid_loss:.4f},"
-                f" best valid loss: {params.best_valid_loss:.4f} "
-                f"best valid epoch: {params.best_valid_epoch}"
-            )
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
             if tb_writer is not None:
-                tb_writer.add_scalar(
-                    "train/valid_ctc_loss",
-                    params.valid_ctc_loss,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/valid_att_loss",
-                    params.valid_att_loss,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/valid_loss",
-                    params.valid_loss,
-                    params.batch_idx_train,
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
                 )

-    params.train_loss = params.tot_loss / params.tot_frames
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value

     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
         params.best_train_loss = params.train_loss
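
One behavioural change worth noting in train_one_epoch(): tot_loss is no longer cleared every reset_interval batches. Instead the accumulated tracker is multiplied by (1 - 1/params.reset_interval) on every batch before the new statistics are added, so older batches fade out geometrically rather than being dropped at a reset boundary. A hypothetical illustration (R stands in for params.reset_interval; the yesno recipe later in this diff sets it to 20):

R = 20
weights = [(1 - 1 / R) ** k for k in range(4)]
# a batch seen k steps ago contributes with weight (1 - 1/R) ** k,
# approximately [1.0, 0.95, 0.9025, 0.857] here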


@@ -57,13 +57,13 @@ log() {
 log "dl_dir: $dl_dir"

 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
-  log "stage -1: Download LM"
+  log "Stage -1: Download LM"
   [ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm
   ./local/download_lm.py --out-dir=$dl_dir/lm
 fi

 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
-  log "stage 0: Download data"
+  log "Stage 0: Download data"

   # If you have pre-downloaded it to /path/to/LibriSpeech,
   # you can create a symlink
@@ -126,7 +126,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
 fi

 if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-  log "State 6: Prepare BPE based lang"
+  log "Stage 6: Prepare BPE based lang"

   for vocab_size in ${vocab_sizes[@]}; do
     lang_dir=data/lang_bpe_${vocab_size}


@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
+#                                        Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -20,14 +21,15 @@ import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional
+from typing import Optional, Tuple

 import k2
 import torch
-import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
+from torch import Tensor
 from asr_datamodule import LibriSpeechAsrDataModule
 from lhotse.utils import fix_random_seed
 from model import TdnnLstm
@@ -43,6 +45,7 @@ from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    MetricsTracker,
     encode_supervisions,
     setup_logger,
     str2bool,
@@ -267,7 +270,7 @@ def compute_loss(
     batch: dict,
     graph_compiler: CtcTrainingGraphCompiler,
     is_training: bool,
-):
+) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
@@ -324,13 +327,11 @@ def compute_loss(
     assert loss.requires_grad == is_training

-    # train_frames and valid_frames are used for printing.
-    if is_training:
-        params.train_frames = supervision_segments[:, 2].sum().item()
-    else:
-        params.valid_frames = supervision_segments[:, 2].sum().item()
+    info = MetricsTracker()
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    info["loss"] = loss.detach().cpu().item()

-    return loss
+    return loss, info


 def compute_validation_loss(
@@ -339,16 +340,16 @@ def compute_validation_loss(
     graph_compiler: CtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
-) -> None:
+) -> MetricsTracker:
     """Run the validation process. The validation loss
     is saved in `params.valid_loss`.
     """
     model.eval()

-    tot_loss = 0.0
-    tot_frames = 0.0
+    tot_loss = MetricsTracker()

     for batch_idx, batch in enumerate(valid_dl):
-        loss = compute_loss(
+        loss, loss_info = compute_loss(
             params=params,
             model=model,
             batch=batch,
@@ -357,22 +358,18 @@ def compute_validation_loss(
         )
         assert loss.requires_grad is False

-        loss_cpu = loss.detach().cpu().item()
-        tot_loss += loss_cpu
-        tot_frames += params.valid_frames
+        tot_loss = tot_loss + loss_info

     if world_size > 1:
-        s = torch.tensor([tot_loss, tot_frames], device=loss.device)
-        dist.all_reduce(s, op=dist.ReduceOp.SUM)
-        s = s.cpu().tolist()
-        tot_loss = s[0]
-        tot_frames = s[1]
+        tot_loss.reduce(loss.device)

-    params.valid_loss = tot_loss / tot_frames
-    if params.valid_loss < params.best_valid_loss:
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
         params.best_valid_epoch = params.cur_epoch
-        params.best_valid_loss = params.valid_loss
+        params.best_valid_loss = loss_value
+
+    return tot_loss


 def train_one_epoch(
@@ -411,67 +408,45 @@ def train_one_epoch(
     """
     model.train()

-    tot_loss = 0.0  # reset after params.reset_interval of batches
-    tot_frames = 0.0  # reset after params.reset_interval of batches
-
-    params.tot_loss = 0.0
-    params.tot_frames = 0.0
+    tot_loss = MetricsTracker()

     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1

         batch_size = len(batch["supervisions"]["text"])

-        loss = compute_loss(
+        loss, loss_info = compute_loss(
             params=params,
             model=model,
             batch=batch,
             graph_compiler=graph_compiler,
             is_training=True,
         )
+        # summary stats.
+        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

-        # NOTE: We use reduction==sum and loss is computed over utterances
-        # in the batch and there is no normalization to it so far.
         optimizer.zero_grad()
         loss.backward()
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()

-        loss_cpu = loss.detach().cpu().item()
-        tot_frames += params.train_frames
-        tot_loss += loss_cpu
-        tot_avg_loss = tot_loss / tot_frames
-
-        params.tot_frames += params.train_frames
-        params.tot_loss += loss_cpu
-
         if batch_idx % params.log_interval == 0:
             logging.info(
-                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
-                f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
-                f"total avg loss: {tot_avg_loss:.4f}, "
-                f"batch size: {batch_size}"
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}"
             )

-        if batch_idx % 10 == 0:
             if tb_writer is not None:
-                tb_writer.add_scalar(
-                    "train/current_loss",
-                    loss_cpu / params.train_frames,
-                    params.batch_idx_train,
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
                 )
-                tb_writer.add_scalar(
-                    "train/tot_avg_loss",
-                    tot_avg_loss,
-                    params.batch_idx_train,
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
                 )

-        if batch_idx > 0 and batch_idx % params.reset_interval == 0:
-            tot_loss = 0
-            tot_frames = 0
-
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
-            compute_validation_loss(
+            valid_info = compute_validation_loss(
                 params=params,
                 model=model,
                 graph_compiler=graph_compiler,
@@ -479,13 +454,16 @@ def train_one_epoch(
                 world_size=world_size,
             )
             model.train()
-            logging.info(
-                f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f},"
-                f" best valid loss: {params.best_valid_loss:.4f} "
-                f"best valid epoch: {params.best_valid_epoch}"
-            )
+            logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer,
+                    "train/valid_",
+                    params.batch_idx_train,
+                )

-    params.train_loss = params.tot_loss / params.tot_frames
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value

     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch


@@ -24,7 +24,7 @@ log() {
 log "dl_dir: $dl_dir"

 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
-  log "stage 0: Download data"
+  log "Stage 0: Download data"
   mkdir -p $dl_dir

   if [ ! -f $dl_dir/waves_yesno/.completed ]; then


@@ -4,14 +4,14 @@ import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional
+from typing import Optional, Tuple

 import k2
 import torch
-import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
+from torch import Tensor
 from asr_datamodule import YesNoAsrDataModule
 from lhotse.utils import fix_random_seed
 from model import Tdnn
@@ -24,7 +24,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, setup_logger, str2bool
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool


 def get_parser():
@@ -122,6 +122,8 @@ def get_params() -> AttributeDict:
         - valid_interval: Run validation if batch_idx % valid_interval` is 0

+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
         - beam_size: It is used in k2.ctc_loss

         - reduction: It is used in k2.ctc_loss
@@ -142,6 +144,7 @@ def get_params() -> AttributeDict:
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
             "log_interval": 10,
+            "reset_interval": 20,
             "valid_interval": 10,
             "beam_size": 10,
             "reduction": "sum",
@@ -245,7 +248,7 @@ def compute_loss(
     batch: dict,
     graph_compiler: CtcTrainingGraphCompiler,
     is_training: bool,
-):
+) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
@@ -305,13 +308,11 @@ def compute_loss(
     assert loss.requires_grad == is_training

-    # train_frames and valid_frames are used for printing.
-    if is_training:
-        params.train_frames = supervision_segments[:, 2].sum().item()
-    else:
-        params.valid_frames = supervision_segments[:, 2].sum().item()
+    info = MetricsTracker()
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    info["loss"] = loss.detach().cpu().item()

-    return loss
+    return loss, info


 def compute_validation_loss(
@@ -320,16 +321,16 @@ def compute_validation_loss(
     graph_compiler: CtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
-) -> None:
+) -> MetricsTracker:
     """Run the validation process. The validation loss
     is saved in `params.valid_loss`.
     """
     model.eval()

-    tot_loss = 0.0
-    tot_frames = 0.0
+    tot_loss = MetricsTracker()

     for batch_idx, batch in enumerate(valid_dl):
-        loss = compute_loss(
+        loss, loss_info = compute_loss(
             params=params,
             model=model,
             batch=batch,
@@ -338,22 +339,18 @@ def compute_validation_loss(
         )
         assert loss.requires_grad is False

-        loss_cpu = loss.detach().cpu().item()
-        tot_loss += loss_cpu
-        tot_frames += params.valid_frames
+        tot_loss = tot_loss + loss_info

     if world_size > 1:
-        s = torch.tensor([tot_loss, tot_frames], device=loss.device)
-        dist.all_reduce(s, op=dist.ReduceOp.SUM)
-        s = s.cpu().tolist()
-        tot_loss = s[0]
-        tot_frames = s[1]
+        tot_loss.reduce(loss.device)

-    params.valid_loss = tot_loss / tot_frames
-    if params.valid_loss < params.best_valid_loss:
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
         params.best_valid_epoch = params.cur_epoch
-        params.best_valid_loss = params.valid_loss
+        params.best_valid_loss = loss_value
+
+    return tot_loss


 def train_one_epoch(
@@ -392,57 +389,45 @@ def train_one_epoch(
     """
     model.train()

-    tot_loss = 0.0  # sum of losses over all batches
-    tot_frames = 0.0  # sum of frames over all batches
+    tot_loss = MetricsTracker()

     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1

         batch_size = len(batch["supervisions"]["text"])

-        loss = compute_loss(
+        loss, loss_info = compute_loss(
             params=params,
             model=model,
             batch=batch,
             graph_compiler=graph_compiler,
             is_training=True,
         )
+        # summary stats.
+        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

-        # NOTE: We use reduction==sum and loss is computed over utterances
-        # in the batch and there is no normalization to it so far.
         optimizer.zero_grad()
         loss.backward()
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()

-        loss_cpu = loss.detach().cpu().item()
-        tot_frames += params.train_frames
-        tot_loss += loss_cpu
-        tot_avg_loss = tot_loss / tot_frames
-
         if batch_idx % params.log_interval == 0:
             logging.info(
-                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
-                f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
-                f"total avg loss: {tot_avg_loss:.4f}, "
-                f"batch size: {batch_size}"
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}"
             )

-        if batch_idx % 10 == 0:
             if tb_writer is not None:
-                tb_writer.add_scalar(
-                    "train/current_loss",
-                    loss_cpu / params.train_frames,
-                    params.batch_idx_train,
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
                 )
-                tb_writer.add_scalar(
-                    "train/tot_avg_loss",
-                    tot_avg_loss,
-                    params.batch_idx_train,
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
                 )

         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
-            compute_validation_loss(
+            valid_info = compute_validation_loss(
                 params=params,
                 model=model,
                 graph_compiler=graph_compiler,
@@ -450,19 +435,16 @@ def train_one_epoch(
                 world_size=world_size,
             )
             model.train()
-            logging.info(
-                f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f},"
-                f" best valid loss: {params.best_valid_loss:.4f} "
-                f"best valid epoch: {params.best_valid_epoch}"
-            )
+            logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")
             if tb_writer is not None:
-                tb_writer.add_scalar(
-                    "train/valid_loss",
-                    params.valid_loss,
+                valid_info.write_summary(
+                    tb_writer,
+                    "train/valid_",
                     params.batch_idx_train,
                 )

-    params.train_loss = tot_loss / tot_frames
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value

     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch


@@ -1,4 +1,5 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
+#                                        Mingshuang Luo)
 #
 # See ../../LICENSE for clarification regarding multiple authors
 #
@@ -17,6 +18,7 @@
 import argparse
 import logging
+import collections
 import os
 import subprocess
 from collections import defaultdict
@@ -29,6 +31,7 @@ import k2
 import kaldialign
 import torch
 import torch.distributed as dist
+from torch.utils.tensorboard import SummaryWriter

 Pathlike = Union[str, Path]
@@ -166,8 +169,8 @@ def encode_supervisions(
     supervisions: dict, subsampling_factor: int
 ) -> Tuple[torch.Tensor, List[str]]:
     """
-    Encodes Lhotse's ``batch["supervisions"]`` dict into a pair of torch Tensor,
-    and a list of transcription strings.
+    Encodes Lhotse's ``batch["supervisions"]`` dict into
+    a pair of torch Tensor, and a list of transcription strings.

     The supervision tensor has shape ``(batch_size, 3)``.
     Its second dimension contains information about sequence index [0],
@@ -272,13 +275,13 @@ def write_error_stats(
         Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
         reference words (2337 correct)

-    - The difference between the reference transcript and predicted results.
+    - The difference between the reference transcript and predicted result.
       An instance is given below::

             THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES

-      The above example shows that the reference word is `EDISON`, but it is
-      predicted to `ADDISON` (a substitution error).
+      The above example shows that the reference word is `EDISON`,
+      but it is predicted to `ADDISON` (a substitution error).

       Another example is::
@@ -419,3 +422,76 @@ def write_error_stats(
         print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)

     return float(tot_err_rate)
+
+
+class MetricsTracker(collections.defaultdict):
+    def __init__(self):
+        # Passing the type 'int' to the base-class constructor
+        # makes undefined items default to int() which is zero.
+        # This class will play a role as metrics tracker.
+        # It can record many metrics, including but not limited to loss.
+        super(MetricsTracker, self).__init__(int)
+
+    def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
+        ans = MetricsTracker()
+        for k, v in self.items():
+            ans[k] = v
+        for k, v in other.items():
+            ans[k] = ans[k] + v
+        return ans
+
+    def __mul__(self, alpha: float) -> "MetricsTracker":
+        ans = MetricsTracker()
+        for k, v in self.items():
+            ans[k] = v * alpha
+        return ans
+
+    def __str__(self) -> str:
+        ans = ""
+        for k, v in self.norm_items():
+            norm_value = "%.4g" % v
+            ans += str(k) + "=" + str(norm_value) + ", "
+        frames = str(self["frames"])
+        ans += "over " + frames + " frames."
+        return ans
+
+    def norm_items(self) -> List[Tuple[str, float]]:
+        """
+        Returns a list of pairs, like:
+          [('ctc_loss', 0.1), ('att_loss', 0.07)]
+        """
+        num_frames = self["frames"] if "frames" in self else 1
+        ans = []
+        for k, v in self.items():
+            if k != "frames":
+                norm_value = float(v) / num_frames
+                ans.append((k, norm_value))
+        return ans
+
+    def reduce(self, device):
+        """
+        Reduce using torch.distributed, which I believe ensures that
+        all processes get the total.
+        """
+        keys = sorted(self.keys())
+        s = torch.tensor([float(self[k]) for k in keys], device=device)
+        dist.all_reduce(s, op=dist.ReduceOp.SUM)
+        for k, v in zip(keys, s.cpu().tolist()):
+            self[k] = v
+
+    def write_summary(
+        self,
+        tb_writer: SummaryWriter,
+        prefix: str,
+        batch_idx: int,
+    ) -> None:
+        """Add logging information to a TensorBoard writer.
+
+        Args:
+            tb_writer: a TensorBoard writer
+            prefix: a prefix for the name of the loss, e.g. "train/valid_",
+                or "train/current_"
+            batch_idx: The current batch index, used as the x-axis of the plot.
+        """
+        for k, v in self.norm_items():
+            tb_writer.add_scalar(prefix + k, v, batch_idx)