diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index 80b2d924a..3e1049fbf 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
-#                                        Wei Kang)
+#                                        Wei Kang
+#                                        Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -21,13 +22,15 @@
 import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional
+from typing import Optional, Tuple
+
 import k2
 import torch
-import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
+from torch import Tensor
+
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 from lhotse.utils import fix_random_seed
@@ -43,6 +46,7 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    MetricsTracker,
     encode_supervisions,
     setup_logger,
     str2bool,
@@ -287,7 +291,7 @@ def compute_loss(
     batch: dict,
     graph_compiler: BpeCtcTrainingGraphCompiler,
     is_training: bool,
-):
+) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
 
@@ -367,15 +371,17 @@ def compute_loss(
         loss = ctc_loss
         att_loss = torch.tensor([0])
 
-    # train_frames and valid_frames are used for printing.
-    if is_training:
-        params.train_frames = supervision_segments[:, 2].sum().item()
-    else:
-        params.valid_frames = supervision_segments[:, 2].sum().item()
-
     assert loss.requires_grad == is_training
 
-    return loss, ctc_loss.detach(), att_loss.detach()
+    info = MetricsTracker()
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    info["ctc_loss"] = ctc_loss.detach().cpu().item()
+    if params.att_rate != 0.0:
+        info["att_loss"] = att_loss.detach().cpu().item()
+
+    info["loss"] = loss.detach().cpu().item()
+
+    return loss, info
 
 
 def compute_validation_loss(
@@ -384,18 +390,14 @@
     graph_compiler: BpeCtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
-) -> None:
-    """Run the validation process. The validation loss
-    is saved in `params.valid_loss`.
- """ +) -> MetricsTracker: + """Run the validation process.""" model.eval() - tot_loss = 0.0 - tot_ctc_loss = 0.0 - tot_att_loss = 0.0 - tot_frames = 0.0 + tot_loss = MetricsTracker() + for batch_idx, batch in enumerate(valid_dl): - loss, ctc_loss, att_loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, @@ -403,36 +405,17 @@ def compute_validation_loss( is_training=False, ) assert loss.requires_grad is False - assert ctc_loss.requires_grad is False - assert att_loss.requires_grad is False - - loss_cpu = loss.detach().cpu().item() - tot_loss += loss_cpu - - tot_ctc_loss += ctc_loss.detach().cpu().item() - tot_att_loss += att_loss.detach().cpu().item() - - tot_frames += params.valid_frames + tot_loss = tot_loss + loss_info if world_size > 1: - s = torch.tensor( - [tot_loss, tot_ctc_loss, tot_att_loss, tot_frames], - device=loss.device, - ) - dist.all_reduce(s, op=dist.ReduceOp.SUM) - s = s.cpu().tolist() - tot_loss = s[0] - tot_ctc_loss = s[1] - tot_att_loss = s[2] - tot_frames = s[3] + tot_loss.reduce(loss.device) - params.valid_loss = tot_loss / tot_frames - params.valid_ctc_loss = tot_ctc_loss / tot_frames - params.valid_att_loss = tot_att_loss / tot_frames - - if params.valid_loss < params.best_valid_loss: + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = params.valid_loss + params.best_valid_loss = loss_value + + return tot_loss def train_one_epoch( @@ -471,24 +454,21 @@ def train_one_epoch( """ model.train() - tot_loss = 0.0 # sum of losses over all batches - tot_ctc_loss = 0.0 - tot_att_loss = 0.0 + tot_loss = MetricsTracker() - tot_frames = 0.0 # sum of frames over all batches - params.tot_loss = 0.0 - params.tot_frames = 0.0 for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - loss, ctc_loss, att_loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, graph_compiler=graph_compiler, is_training=True, ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. 
@@ -498,75 +478,26 @@
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()
 
-        loss_cpu = loss.detach().cpu().item()
-        ctc_loss_cpu = ctc_loss.detach().cpu().item()
-        att_loss_cpu = att_loss.detach().cpu().item()
-
-        tot_frames += params.train_frames
-        tot_loss += loss_cpu
-        tot_ctc_loss += ctc_loss_cpu
-        tot_att_loss += att_loss_cpu
-
-        params.tot_frames += params.train_frames
-        params.tot_loss += loss_cpu
-
-        tot_avg_loss = tot_loss / tot_frames
-        tot_avg_ctc_loss = tot_ctc_loss / tot_frames
-        tot_avg_att_loss = tot_att_loss / tot_frames
-
         if batch_idx % params.log_interval == 0:
             logging.info(
-                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
-                f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, "
-                f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
-                f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
-                f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, "
-                f"total avg att loss: {tot_avg_att_loss:.4f}, "
-                f"total avg loss: {tot_avg_loss:.4f}, "
-                f"batch size: {batch_size}"
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}"
             )
 
+        if batch_idx % 10 == 0:
+
             if tb_writer is not None:
-                tb_writer.add_scalar(
-                    "train/current_ctc_loss",
-                    ctc_loss_cpu / params.train_frames,
-                    params.batch_idx_train,
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
                 )
-                tb_writer.add_scalar(
-                    "train/current_att_loss",
-                    att_loss_cpu / params.train_frames,
-                    params.batch_idx_train,
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
                 )
-                tb_writer.add_scalar(
-                    "train/current_loss",
-                    loss_cpu / params.train_frames,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/tot_avg_ctc_loss",
-                    tot_avg_ctc_loss,
-                    params.batch_idx_train,
-                )
-
-                tb_writer.add_scalar(
-                    "train/tot_avg_att_loss",
-                    tot_avg_att_loss,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/tot_avg_loss",
-                    tot_avg_loss,
-                    params.batch_idx_train,
-                )
-        if batch_idx > 0 and batch_idx % params.reset_interval == 0:
-            tot_loss = 0.0  # sum of losses over all batches
-            tot_ctc_loss = 0.0
-            tot_att_loss = 0.0
-
-            tot_frames = 0.0  # sum of frames over all batches
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
-            compute_validation_loss(
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
                 params=params,
                 model=model,
                 graph_compiler=graph_compiler,
@@ -574,33 +505,14 @@
                 world_size=world_size,
             )
             model.train()
-            logging.info(
-                f"Epoch {params.cur_epoch}, "
-                f"valid ctc loss {params.valid_ctc_loss:.4f},"
-                f"valid att loss {params.valid_att_loss:.4f},"
-                f"valid loss {params.valid_loss:.4f},"
-                f" best valid loss: {params.best_valid_loss:.4f} "
-                f"best valid epoch: {params.best_valid_epoch}"
-            )
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
             if tb_writer is not None:
-                tb_writer.add_scalar(
-                    "train/valid_ctc_loss",
-                    params.valid_ctc_loss,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/valid_att_loss",
-                    params.valid_att_loss,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/valid_loss",
-                    params.valid_loss,
-                    params.batch_idx_train,
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
                 )
 
-    params.train_loss = params.tot_loss / params.tot_frames
-
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
         params.best_train_loss = params.train_loss
diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh
index 8aa972806..b536cb472 100755
--- a/egs/librispeech/ASR/prepare.sh
+++ b/egs/librispeech/ASR/prepare.sh
@@ -57,13 +57,13 @@ log() {
 log "dl_dir: $dl_dir"
 
 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
-  log "stage -1: Download LM"
+  log "Stage -1: Download LM"
   [ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm
   ./local/download_lm.py --out-dir=$dl_dir/lm
 fi
 
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
-  log "stage 0: Download data"
+  log "Stage 0: Download data"
 
   # If you have pre-downloaded it to /path/to/LibriSpeech,
   # you can create a symlink
@@ -126,7 +126,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
 fi
 
 if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-  log "State 6: Prepare BPE based lang"
+  log "Stage 6: Prepare BPE based lang"
 
   for vocab_size in ${vocab_sizes[@]}; do
     lang_dir=data/lang_bpe_${vocab_size}
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
index 695ee5130..51a486e07 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
+#                                        Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -20,14 +21,15 @@
 import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional
+from typing import Optional, Tuple
 import k2
 import torch
-import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
+from torch import Tensor
+
 from asr_datamodule import LibriSpeechAsrDataModule
 from lhotse.utils import fix_random_seed
 from model import TdnnLstm
@@ -43,6 +45,7 @@ from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    MetricsTracker,
     encode_supervisions,
     setup_logger,
     str2bool,
@@ -267,7 +270,7 @@ def compute_loss(
     batch: dict,
     graph_compiler: CtcTrainingGraphCompiler,
     is_training: bool,
-):
+) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
 
@@ -324,13 +327,11 @@ def compute_loss(
 
     assert loss.requires_grad == is_training
 
-    # train_frames and valid_frames are used for printing.
-    if is_training:
-        params.train_frames = supervision_segments[:, 2].sum().item()
-    else:
-        params.valid_frames = supervision_segments[:, 2].sum().item()
+    info = MetricsTracker()
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    info["loss"] = loss.detach().cpu().item()
 
-    return loss
+    return loss, info
 
 
 def compute_validation_loss(
@@ -339,16 +340,16 @@
     graph_compiler: CtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
-) -> None:
+) -> MetricsTracker:
     """Run the validation process. The validation loss
     is saved in `params.valid_loss`.
""" model.eval() - tot_loss = 0.0 - tot_frames = 0.0 + tot_loss = MetricsTracker() + for batch_idx, batch in enumerate(valid_dl): - loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, @@ -357,22 +358,18 @@ def compute_validation_loss( ) assert loss.requires_grad is False - loss_cpu = loss.detach().cpu().item() - tot_loss += loss_cpu - tot_frames += params.valid_frames + tot_loss = tot_loss + loss_info if world_size > 1: - s = torch.tensor([tot_loss, tot_frames], device=loss.device) - dist.all_reduce(s, op=dist.ReduceOp.SUM) - s = s.cpu().tolist() - tot_loss = s[0] - tot_frames = s[1] + tot_loss.reduce(loss.device) - params.valid_loss = tot_loss / tot_frames + loss_value = tot_loss["loss"] / tot_loss["frames"] - if params.valid_loss < params.best_valid_loss: + if loss_value < params.best_valid_loss: params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = params.valid_loss + params.best_valid_loss = loss_value + + return tot_loss def train_one_epoch( @@ -411,67 +408,45 @@ def train_one_epoch( """ model.train() - tot_loss = 0.0 # reset after params.reset_interval of batches - tot_frames = 0.0 # reset after params.reset_interval of batches - - params.tot_loss = 0.0 - params.tot_frames = 0.0 + tot_loss = MetricsTracker() for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, graph_compiler=graph_compiler, is_training=True, ) - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. + # summary stats. + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info optimizer.zero_grad() loss.backward() clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() - loss_cpu = loss.detach().cpu().item() - - tot_frames += params.train_frames - tot_loss += loss_cpu - tot_avg_loss = tot_loss / tot_frames - - params.tot_frames += params.train_frames - params.tot_loss += loss_cpu - if batch_idx % params.log_interval == 0: logging.info( - f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"batch avg loss {loss_cpu/params.train_frames:.4f}, " - f"total avg loss: {tot_avg_loss:.4f}, " - f"batch size: {batch_size}" + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" ) + if batch_idx % 10 == 0: + if tb_writer is not None: - tb_writer.add_scalar( - "train/current_loss", - loss_cpu / params.train_frames, - params.batch_idx_train, + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train ) - - tb_writer.add_scalar( - "train/tot_avg_loss", - tot_avg_loss, - params.batch_idx_train, + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train ) - if batch_idx > 0 and batch_idx % params.reset_interval == 0: - tot_loss = 0 - tot_frames = 0 - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - compute_validation_loss( + valid_info = compute_validation_loss( params=params, model=model, graph_compiler=graph_compiler, @@ -479,13 +454,16 @@ def train_one_epoch( world_size=world_size, ) model.train() - logging.info( - f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f}," - f" best valid loss: {params.best_valid_loss:.4f} " - f"best valid epoch: {params.best_valid_epoch}" - ) + logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}") + if tb_writer is not None: + 
+                valid_info.write_summary(
+                    tb_writer,
+                    "train/valid_",
+                    params.batch_idx_train,
+                )
 
-    params.train_loss = params.tot_loss / params.tot_frames
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
 
     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
diff --git a/egs/yesno/ASR/prepare.sh b/egs/yesno/ASR/prepare.sh
index 9a0cc48bb..8fcee0290 100755
--- a/egs/yesno/ASR/prepare.sh
+++ b/egs/yesno/ASR/prepare.sh
@@ -24,7 +24,7 @@ log() {
 log "dl_dir: $dl_dir"
 
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
-  log "stage 0: Download data"
+  log "Stage 0: Download data"
 
   mkdir -p $dl_dir
   if [ ! -f $dl_dir/waves_yesno/.completed ]; then
diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py
index 0f5506d38..6cc511a28 100755
--- a/egs/yesno/ASR/tdnn/train.py
+++ b/egs/yesno/ASR/tdnn/train.py
@@ -4,14 +4,14 @@
 import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional
+from typing import Optional, Tuple
 import k2
 import torch
-import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
+from torch import Tensor
 from asr_datamodule import YesNoAsrDataModule
 from lhotse.utils import fix_random_seed
 from model import Tdnn
@@ -24,7 +24,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, setup_logger, str2bool
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
 
 def get_parser():
@@ -122,6 +122,8 @@ def get_params() -> AttributeDict:
 
         - valid_interval:  Run validation if batch_idx % valid_interval` is 0
 
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
         - beam_size: It is used in k2.ctc_loss
 
         - reduction: It is used in k2.ctc_loss
@@ -142,6 +144,7 @@ def get_params() -> AttributeDict:
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
             "log_interval": 10,
+            "reset_interval": 20,
             "valid_interval": 10,
             "beam_size": 10,
             "reduction": "sum",
@@ -245,7 +248,7 @@ def compute_loss(
     batch: dict,
     graph_compiler: CtcTrainingGraphCompiler,
     is_training: bool,
-):
+) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
 
@@ -305,13 +308,11 @@ def compute_loss(
 
     assert loss.requires_grad == is_training
 
-    # train_frames and valid_frames are used for printing.
-    if is_training:
-        params.train_frames = supervision_segments[:, 2].sum().item()
-    else:
-        params.valid_frames = supervision_segments[:, 2].sum().item()
+    info = MetricsTracker()
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    info["loss"] = loss.detach().cpu().item()
 
-    return loss
+    return loss, info
 
 
 def compute_validation_loss(
@@ -320,16 +321,16 @@
     graph_compiler: CtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
-) -> None:
+) -> MetricsTracker:
     """Run the validation process. The validation loss
     is saved in `params.valid_loss`.
""" model.eval() - tot_loss = 0.0 - tot_frames = 0.0 + tot_loss = MetricsTracker() + for batch_idx, batch in enumerate(valid_dl): - loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, @@ -338,22 +339,18 @@ def compute_validation_loss( ) assert loss.requires_grad is False - loss_cpu = loss.detach().cpu().item() - tot_loss += loss_cpu - tot_frames += params.valid_frames + tot_loss = tot_loss + loss_info if world_size > 1: - s = torch.tensor([tot_loss, tot_frames], device=loss.device) - dist.all_reduce(s, op=dist.ReduceOp.SUM) - s = s.cpu().tolist() - tot_loss = s[0] - tot_frames = s[1] + tot_loss.reduce(loss.device) - params.valid_loss = tot_loss / tot_frames + loss_value = tot_loss["loss"] / tot_loss["frames"] - if params.valid_loss < params.best_valid_loss: + if loss_value < params.best_valid_loss: params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = params.valid_loss + params.best_valid_loss = loss_value + + return tot_loss def train_one_epoch( @@ -392,57 +389,45 @@ def train_one_epoch( """ model.train() - tot_loss = 0.0 # sum of losses over all batches - tot_frames = 0.0 # sum of frames over all batches + tot_loss = MetricsTracker() + for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, graph_compiler=graph_compiler, is_training=True, ) - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. + # summary stats. + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info optimizer.zero_grad() loss.backward() clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() - loss_cpu = loss.detach().cpu().item() - - tot_frames += params.train_frames - tot_loss += loss_cpu - tot_avg_loss = tot_loss / tot_frames - if batch_idx % params.log_interval == 0: logging.info( - f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"batch avg loss {loss_cpu/params.train_frames:.4f}, " - f"total avg loss: {tot_avg_loss:.4f}, " - f"batch size: {batch_size}" + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" ) + if batch_idx % 10 == 0: if tb_writer is not None: - tb_writer.add_scalar( - "train/current_loss", - loss_cpu / params.train_frames, - params.batch_idx_train, + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train ) - - tb_writer.add_scalar( - "train/tot_avg_loss", - tot_avg_loss, - params.batch_idx_train, + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: - compute_validation_loss( + valid_info = compute_validation_loss( params=params, model=model, graph_compiler=graph_compiler, @@ -450,19 +435,16 @@ def train_one_epoch( world_size=world_size, ) model.train() - logging.info( - f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f}," - f" best valid loss: {params.best_valid_loss:.4f} " - f"best valid epoch: {params.best_valid_epoch}" - ) + logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}") if tb_writer is not None: - tb_writer.add_scalar( - "train/valid_loss", - params.valid_loss, + valid_info.write_summary( + tb_writer, + "train/valid_", params.batch_idx_train, ) - params.train_loss = tot_loss / tot_frames + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = 
 
     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
diff --git a/icefall/utils.py b/icefall/utils.py
index 23b4dd6c7..66aa5c601 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -1,4 +1,5 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
+#                                        Mingshuang Luo)
 #
 # See ../../LICENSE for clarification regarding multiple authors
 #
@@ -17,6 +18,7 @@
 
 import argparse
 import logging
+import collections
 import os
 import subprocess
 from collections import defaultdict
@@ -29,6 +31,7 @@
 import k2
 import kaldialign
 import torch
 import torch.distributed as dist
+from torch.utils.tensorboard import SummaryWriter
 
 Pathlike = Union[str, Path]
@@ -166,8 +169,8 @@ def encode_supervisions(
     supervisions: dict, subsampling_factor: int
 ) -> Tuple[torch.Tensor, List[str]]:
     """
-    Encodes Lhotse's ``batch["supervisions"]`` dict into a pair of torch Tensor,
-    and a list of transcription strings.
+    Encodes Lhotse's ``batch["supervisions"]`` dict into
+    a pair of torch Tensor, and a list of transcription strings.
 
     The supervision tensor has shape ``(batch_size, 3)``.
     Its second dimension contains information about sequence index [0],
@@ -272,13 +275,13 @@ def write_error_stats(
         Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
         reference words (2337 correct)
 
-    - The difference between the reference transcript and predicted results.
+    - The difference between the reference transcript and predicted result.
      An instance is given below::
 
         THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
 
-    The above example shows that the reference word is `EDISON`, but it is
-    predicted to `ADDISON` (a substitution error).
+    The above example shows that the reference word is `EDISON`,
+    but it is predicted to `ADDISON` (a substitution error).
 
     Another example is::
 
@@ -419,3 +422,76 @@ def write_error_stats(
         print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
 
     return float(tot_err_rate)
+
+
+class MetricsTracker(collections.defaultdict):
+    def __init__(self):
+        # Passing the type 'int' to the base-class constructor
+        # makes undefined items default to int() which is zero.
+        # This class will play a role as metrics tracker.
+        # It can record many metrics, including but not limited to loss.
+        super(MetricsTracker, self).__init__(int)
+
+    def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
+        ans = MetricsTracker()
+        for k, v in self.items():
+            ans[k] = v
+        for k, v in other.items():
+            ans[k] = ans[k] + v
+        return ans
+
+    def __mul__(self, alpha: float) -> "MetricsTracker":
+        ans = MetricsTracker()
+        for k, v in self.items():
+            ans[k] = v * alpha
+        return ans
+
+    def __str__(self) -> str:
+        ans = ""
+        for k, v in self.norm_items():
+            norm_value = "%.4g" % v
+            ans += str(k) + "=" + str(norm_value) + ", "
+        frames = str(self["frames"])
+        ans += "over " + frames + " frames."
+        return ans
+
+    def norm_items(self) -> List[Tuple[str, float]]:
+        """
+        Returns a list of pairs, like:
+        [('ctc_loss', 0.1), ('att_loss', 0.07)]
+        """
+        num_frames = self["frames"] if "frames" in self else 1
+        ans = []
+        for k, v in self.items():
+            if k != "frames":
+                norm_value = float(v) / num_frames
+                ans.append((k, norm_value))
+        return ans
+
+    def reduce(self, device):
+        """
+        Reduce using torch.distributed, which I believe ensures that
+        all processes get the total.
+ """ + keys = sorted(self.keys()) + s = torch.tensor([float(self[k]) for k in keys], device=device) + dist.all_reduce(s, op=dist.ReduceOp.SUM) + for k, v in zip(keys, s.cpu().tolist()): + self[k] = v + + def write_summary( + self, + tb_writer: SummaryWriter, + prefix: str, + batch_idx: int, + ) -> None: + """Add logging information to a TensorBoard writer. + + Args: + tb_writer: a TensorBoard writer + prefix: a prefix for the name of the loss, e.g. "train/valid_", + or "train/current_" + batch_idx: The current batch index, used as the x-axis of the plot. + """ + for k, v in self.norm_items(): + tb_writer.add_scalar(prefix + k, v, batch_idx)