diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index 1870cb572..3e1049fbf 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -46,7 +46,7 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
-    LossRecord,
+    MetricsTracker,
     encode_supervisions,
     setup_logger,
     str2bool,
@@ -291,7 +291,7 @@ def compute_loss(
     batch: dict,
     graph_compiler: BpeCtcTrainingGraphCompiler,
     is_training: bool,
-) -> Tuple[Tensor, LossRecord]:
+) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
 
@@ -373,7 +373,7 @@ def compute_loss(
 
     assert loss.requires_grad == is_training
 
-    info = LossRecord()
+    info = MetricsTracker()
     info["frames"] = supervision_segments[:, 2].sum().item()
     info["ctc_loss"] = ctc_loss.detach().cpu().item()
     if params.att_rate != 0.0:
@@ -390,11 +390,11 @@ def compute_validation_loss(
     graph_compiler: BpeCtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
-) -> LossRecord:
+) -> MetricsTracker:
     """Run the validation process."""
     model.eval()
 
-    tot_loss = LossRecord()
+    tot_loss = MetricsTracker()
 
     for batch_idx, batch in enumerate(valid_dl):
         loss, loss_info = compute_loss(
@@ -454,7 +454,7 @@ def train_one_epoch(
     """
     model.train()
 
-    tot_loss = LossRecord()
+    tot_loss = MetricsTracker()
 
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
index 07292fe5c..51a486e07 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
@@ -45,7 +45,7 @@ from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
-    LossRecord,
+    MetricsTracker,
     encode_supervisions,
     setup_logger,
     str2bool,
@@ -270,7 +270,7 @@ def compute_loss(
     batch: dict,
     graph_compiler: CtcTrainingGraphCompiler,
     is_training: bool,
-) -> Tuple[Tensor, LossRecord]:
+) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
 
@@ -327,7 +327,7 @@ def compute_loss(
 
     assert loss.requires_grad == is_training
 
-    info = LossRecord()
+    info = MetricsTracker()
     info["frames"] = supervision_segments[:, 2].sum().item()
     info["loss"] = loss.detach().cpu().item()
 
@@ -340,13 +340,13 @@ def compute_validation_loss(
     graph_compiler: CtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
-) -> LossRecord:
+) -> MetricsTracker:
     """Run the validation process. The validation loss
     is saved in `params.valid_loss`.
     """
     model.eval()
 
-    tot_loss = LossRecord()
+    tot_loss = MetricsTracker()
 
     for batch_idx, batch in enumerate(valid_dl):
         loss, loss_info = compute_loss(
@@ -408,7 +408,7 @@ def train_one_epoch(
     """
     model.train()
 
-    tot_loss = LossRecord()
+    tot_loss = MetricsTracker()
 
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py
index b2a213b2b..6cc511a28 100755
--- a/egs/yesno/ASR/tdnn/train.py
+++ b/egs/yesno/ASR/tdnn/train.py
@@ -24,7 +24,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, LossRecord, setup_logger, str2bool
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
 
 def get_parser():
@@ -248,7 +248,7 @@ def compute_loss(
     batch: dict,
     graph_compiler: CtcTrainingGraphCompiler,
     is_training: bool,
-) -> Tuple[Tensor, LossRecord]:
+) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
 
@@ -308,7 +308,7 @@ def compute_loss(
 
     assert loss.requires_grad == is_training
 
-    info = LossRecord()
+    info = MetricsTracker()
     info["frames"] = supervision_segments[:, 2].sum().item()
     info["loss"] = loss.detach().cpu().item()
 
@@ -321,13 +321,13 @@ def compute_validation_loss(
     graph_compiler: CtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
-) -> LossRecord:
+) -> MetricsTracker:
     """Run the validation process. The validation loss
     is saved in `params.valid_loss`.
     """
     model.eval()
 
-    tot_loss = LossRecord()
+    tot_loss = MetricsTracker()
 
     for batch_idx, batch in enumerate(valid_dl):
         loss, loss_info = compute_loss(
@@ -389,7 +389,7 @@ def train_one_epoch(
     """
     model.train()
 
-    tot_loss = LossRecord()
+    tot_loss = MetricsTracker()
 
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
diff --git a/icefall/utils.py b/icefall/utils.py
index 7c588a691..66aa5c601 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -424,22 +424,24 @@ def write_error_stats(
     return float(tot_err_rate)
 
 
-class LossRecord(collections.defaultdict):
+class MetricsTracker(collections.defaultdict):
     def __init__(self):
         # Passing the type 'int' to the base-class constructor
         # makes undefined items default to int() which is zero.
-        super(LossRecord, self).__init__(int)
+        # This class will play a role as metrics tracker.
+        # It can record many metrics, including but not limited to loss.
+        super(MetricsTracker, self).__init__(int)
 
-    def __add__(self, other: "LossRecord") -> "LossRecord":
-        ans = LossRecord()
+    def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
+        ans = MetricsTracker()
         for k, v in self.items():
             ans[k] = v
         for k, v in other.items():
             ans[k] = ans[k] + v
         return ans
 
-    def __mul__(self, alpha: float) -> "LossRecord":
-        ans = LossRecord()
+    def __mul__(self, alpha: float) -> "MetricsTracker":
+        ans = MetricsTracker()
         for k, v in self.items():
             ans[k] = v * alpha
         return ans