Change the name "LossRecord" to "MetricsTracker"

This commit is contained in:
Mingshuang Luo 2021-09-30 10:05:55 +08:00
parent 9a1d76f7c3
commit ff44415313
4 changed files with 26 additions and 24 deletions

View File

@@ -46,7 +46,7 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
-    LossRecord,
+    MetricsTracker,
     encode_supervisions,
     setup_logger,
     str2bool,
@@ -291,7 +291,7 @@ def compute_loss(
     batch: dict,
     graph_compiler: BpeCtcTrainingGraphCompiler,
     is_training: bool,
-) -> Tuple[Tensor, LossRecord]:
+) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
@@ -373,7 +373,7 @@ def compute_loss(
     assert loss.requires_grad == is_training
-    info = LossRecord()
+    info = MetricsTracker()
     info["frames"] = supervision_segments[:, 2].sum().item()
     info["ctc_loss"] = ctc_loss.detach().cpu().item()
     if params.att_rate != 0.0:
@@ -390,11 +390,11 @@ def compute_validation_loss(
     graph_compiler: BpeCtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
-) -> LossRecord:
+) -> MetricsTracker:
     """Run the validation process."""
     model.eval()
-    tot_loss = LossRecord()
+    tot_loss = MetricsTracker()
     for batch_idx, batch in enumerate(valid_dl):
         loss, loss_info = compute_loss(
@@ -454,7 +454,7 @@ def train_one_epoch(
     """
     model.train()
-    tot_loss = LossRecord()
+    tot_loss = MetricsTracker()
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1

View File

@@ -45,7 +45,7 @@ from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
-    LossRecord,
+    MetricsTracker,
     encode_supervisions,
     setup_logger,
     str2bool,
@@ -270,7 +270,7 @@ def compute_loss(
     batch: dict,
     graph_compiler: CtcTrainingGraphCompiler,
     is_training: bool,
-) -> Tuple[Tensor, LossRecord]:
+) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
@@ -327,7 +327,7 @@ def compute_loss(
     assert loss.requires_grad == is_training
-    info = LossRecord()
+    info = MetricsTracker()
     info["frames"] = supervision_segments[:, 2].sum().item()
     info["loss"] = loss.detach().cpu().item()
@@ -340,13 +340,13 @@ def compute_validation_loss(
     graph_compiler: CtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
-) -> LossRecord:
+) -> MetricsTracker:
     """Run the validation process. The validation loss
     is saved in `params.valid_loss`.
     """
     model.eval()
-    tot_loss = LossRecord()
+    tot_loss = MetricsTracker()
     for batch_idx, batch in enumerate(valid_dl):
         loss, loss_info = compute_loss(
@@ -408,7 +408,7 @@ def train_one_epoch(
     """
     model.train()
-    tot_loss = LossRecord()
+    tot_loss = MetricsTracker()
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1

View File

@@ -24,7 +24,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, LossRecord, setup_logger, str2bool
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

 def get_parser():
@@ -248,7 +248,7 @@ def compute_loss(
     batch: dict,
     graph_compiler: CtcTrainingGraphCompiler,
     is_training: bool,
-) -> Tuple[Tensor, LossRecord]:
+) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
@@ -308,7 +308,7 @@ def compute_loss(
     assert loss.requires_grad == is_training
-    info = LossRecord()
+    info = MetricsTracker()
     info["frames"] = supervision_segments[:, 2].sum().item()
     info["loss"] = loss.detach().cpu().item()
@@ -321,13 +321,13 @@ def compute_validation_loss(
     graph_compiler: CtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
-) -> LossRecord:
+) -> MetricsTracker:
     """Run the validation process. The validation loss
     is saved in `params.valid_loss`.
     """
     model.eval()
-    tot_loss = LossRecord()
+    tot_loss = MetricsTracker()
     for batch_idx, batch in enumerate(valid_dl):
         loss, loss_info = compute_loss(
@@ -389,7 +389,7 @@ def train_one_epoch(
     """
     model.train()
-    tot_loss = LossRecord()
+    tot_loss = MetricsTracker()
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1

View File

@@ -424,22 +424,24 @@ def write_error_stats(
     return float(tot_err_rate)

-class LossRecord(collections.defaultdict):
+class MetricsTracker(collections.defaultdict):
     def __init__(self):
         # Passing the type 'int' to the base-class constructor
         # makes undefined items default to int() which is zero.
-        super(LossRecord, self).__init__(int)
+        # This class will play a role as metrics tracker.
+        # It can record many metrics, including but not limited to loss.
+        super(MetricsTracker, self).__init__(int)

-    def __add__(self, other: "LossRecord") -> "LossRecord":
-        ans = LossRecord()
+    def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
+        ans = MetricsTracker()
         for k, v in self.items():
             ans[k] = v
         for k, v in other.items():
             ans[k] = ans[k] + v
         return ans

-    def __mul__(self, alpha: float) -> "LossRecord":
-        ans = LossRecord()
+    def __mul__(self, alpha: float) -> "MetricsTracker":
+        ans = MetricsTracker()
         for k, v in self.items():
             ans[k] = v * alpha
         return ans