Change the name "LossRecord" to "MetricsTracker"

This commit is contained in:
Mingshuang Luo 2021-09-30 10:05:55 +08:00
parent 9a1d76f7c3
commit ff44415313
4 changed files with 26 additions and 24 deletions

View File

@ -46,7 +46,7 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
LossRecord,
MetricsTracker,
encode_supervisions,
setup_logger,
str2bool,
@ -291,7 +291,7 @@ def compute_loss(
batch: dict,
graph_compiler: BpeCtcTrainingGraphCompiler,
is_training: bool,
) -> Tuple[Tensor, LossRecord]:
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
@ -373,7 +373,7 @@ def compute_loss(
assert loss.requires_grad == is_training
info = LossRecord()
info = MetricsTracker()
info["frames"] = supervision_segments[:, 2].sum().item()
info["ctc_loss"] = ctc_loss.detach().cpu().item()
if params.att_rate != 0.0:
@ -390,11 +390,11 @@ def compute_validation_loss(
graph_compiler: BpeCtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> LossRecord:
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
tot_loss = LossRecord()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
loss, loss_info = compute_loss(
@ -454,7 +454,7 @@ def train_one_epoch(
"""
model.train()
tot_loss = LossRecord()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1

View File

@ -45,7 +45,7 @@ from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
LossRecord,
MetricsTracker,
encode_supervisions,
setup_logger,
str2bool,
@ -270,7 +270,7 @@ def compute_loss(
batch: dict,
graph_compiler: CtcTrainingGraphCompiler,
is_training: bool,
) -> Tuple[Tensor, LossRecord]:
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
@ -327,7 +327,7 @@ def compute_loss(
assert loss.requires_grad == is_training
info = LossRecord()
info = MetricsTracker()
info["frames"] = supervision_segments[:, 2].sum().item()
info["loss"] = loss.detach().cpu().item()
@ -340,13 +340,13 @@ def compute_validation_loss(
graph_compiler: CtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> LossRecord:
) -> MetricsTracker:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
model.eval()
tot_loss = LossRecord()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
loss, loss_info = compute_loss(
@ -408,7 +408,7 @@ def train_one_epoch(
"""
model.train()
tot_loss = LossRecord()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1

View File

@ -24,7 +24,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, LossRecord, setup_logger, str2bool
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
def get_parser():
@ -248,7 +248,7 @@ def compute_loss(
batch: dict,
graph_compiler: CtcTrainingGraphCompiler,
is_training: bool,
) -> Tuple[Tensor, LossRecord]:
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
@ -308,7 +308,7 @@ def compute_loss(
assert loss.requires_grad == is_training
info = LossRecord()
info = MetricsTracker()
info["frames"] = supervision_segments[:, 2].sum().item()
info["loss"] = loss.detach().cpu().item()
@ -321,13 +321,13 @@ def compute_validation_loss(
graph_compiler: CtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> LossRecord:
) -> MetricsTracker:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
model.eval()
tot_loss = LossRecord()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
loss, loss_info = compute_loss(
@ -389,7 +389,7 @@ def train_one_epoch(
"""
model.train()
tot_loss = LossRecord()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1

View File

@ -424,22 +424,24 @@ def write_error_stats(
return float(tot_err_rate)
class LossRecord(collections.defaultdict):
class MetricsTracker(collections.defaultdict):
def __init__(self):
# Passing the type 'int' to the base-class constructor
# makes undefined items default to int() which is zero.
super(LossRecord, self).__init__(int)
# This class will play a role as metrics tracker.
# It can record many metrics, including but not limited to loss.
super(MetricsTracker, self).__init__(int)
def __add__(self, other: "LossRecord") -> "LossRecord":
ans = LossRecord()
def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
ans = MetricsTracker()
for k, v in self.items():
ans[k] = v
for k, v in other.items():
ans[k] = ans[k] + v
return ans
def __mul__(self, alpha: float) -> "LossRecord":
ans = LossRecord()
def __mul__(self, alpha: float) -> "MetricsTracker":
ans = MetricsTracker()
for k, v in self.items():
ans[k] = v * alpha
return ans