mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Change the name "LossRecord" to "MetricsTracker"
This commit is contained in:
parent
9a1d76f7c3
commit
ff44415313
@ -46,7 +46,7 @@ from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
LossRecord,
|
||||
MetricsTracker,
|
||||
encode_supervisions,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
@ -291,7 +291,7 @@ def compute_loss(
|
||||
batch: dict,
|
||||
graph_compiler: BpeCtcTrainingGraphCompiler,
|
||||
is_training: bool,
|
||||
) -> Tuple[Tensor, LossRecord]:
|
||||
) -> Tuple[Tensor, MetricsTracker]:
|
||||
"""
|
||||
Compute CTC loss given the model and its inputs.
|
||||
|
||||
@ -373,7 +373,7 @@ def compute_loss(
|
||||
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = LossRecord()
|
||||
info = MetricsTracker()
|
||||
info["frames"] = supervision_segments[:, 2].sum().item()
|
||||
info["ctc_loss"] = ctc_loss.detach().cpu().item()
|
||||
if params.att_rate != 0.0:
|
||||
@ -390,11 +390,11 @@ def compute_validation_loss(
|
||||
graph_compiler: BpeCtcTrainingGraphCompiler,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
world_size: int = 1,
|
||||
) -> LossRecord:
|
||||
) -> MetricsTracker:
|
||||
"""Run the validation process."""
|
||||
model.eval()
|
||||
|
||||
tot_loss = LossRecord()
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
for batch_idx, batch in enumerate(valid_dl):
|
||||
loss, loss_info = compute_loss(
|
||||
@ -454,7 +454,7 @@ def train_one_epoch(
|
||||
"""
|
||||
model.train()
|
||||
|
||||
tot_loss = LossRecord()
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
params.batch_idx_train += 1
|
||||
|
@ -45,7 +45,7 @@ from icefall.graph_compiler import CtcTrainingGraphCompiler
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
LossRecord,
|
||||
MetricsTracker,
|
||||
encode_supervisions,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
@ -270,7 +270,7 @@ def compute_loss(
|
||||
batch: dict,
|
||||
graph_compiler: CtcTrainingGraphCompiler,
|
||||
is_training: bool,
|
||||
) -> Tuple[Tensor, LossRecord]:
|
||||
) -> Tuple[Tensor, MetricsTracker]:
|
||||
"""
|
||||
Compute CTC loss given the model and its inputs.
|
||||
|
||||
@ -327,7 +327,7 @@ def compute_loss(
|
||||
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = LossRecord()
|
||||
info = MetricsTracker()
|
||||
info["frames"] = supervision_segments[:, 2].sum().item()
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
||||
@ -340,13 +340,13 @@ def compute_validation_loss(
|
||||
graph_compiler: CtcTrainingGraphCompiler,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
world_size: int = 1,
|
||||
) -> LossRecord:
|
||||
) -> MetricsTracker:
|
||||
"""Run the validation process. The validation loss
|
||||
is saved in `params.valid_loss`.
|
||||
"""
|
||||
model.eval()
|
||||
|
||||
tot_loss = LossRecord()
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
for batch_idx, batch in enumerate(valid_dl):
|
||||
loss, loss_info = compute_loss(
|
||||
@ -408,7 +408,7 @@ def train_one_epoch(
|
||||
"""
|
||||
model.train()
|
||||
|
||||
tot_loss = LossRecord()
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
params.batch_idx_train += 1
|
||||
|
@ -24,7 +24,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, LossRecord, setup_logger, str2bool
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
@ -248,7 +248,7 @@ def compute_loss(
|
||||
batch: dict,
|
||||
graph_compiler: CtcTrainingGraphCompiler,
|
||||
is_training: bool,
|
||||
) -> Tuple[Tensor, LossRecord]:
|
||||
) -> Tuple[Tensor, MetricsTracker]:
|
||||
"""
|
||||
Compute CTC loss given the model and its inputs.
|
||||
|
||||
@ -308,7 +308,7 @@ def compute_loss(
|
||||
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = LossRecord()
|
||||
info = MetricsTracker()
|
||||
info["frames"] = supervision_segments[:, 2].sum().item()
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
||||
@ -321,13 +321,13 @@ def compute_validation_loss(
|
||||
graph_compiler: CtcTrainingGraphCompiler,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
world_size: int = 1,
|
||||
) -> LossRecord:
|
||||
) -> MetricsTracker:
|
||||
"""Run the validation process. The validation loss
|
||||
is saved in `params.valid_loss`.
|
||||
"""
|
||||
model.eval()
|
||||
|
||||
tot_loss = LossRecord()
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
for batch_idx, batch in enumerate(valid_dl):
|
||||
loss, loss_info = compute_loss(
|
||||
@ -389,7 +389,7 @@ def train_one_epoch(
|
||||
"""
|
||||
model.train()
|
||||
|
||||
tot_loss = LossRecord()
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
params.batch_idx_train += 1
|
||||
|
@ -424,22 +424,24 @@ def write_error_stats(
|
||||
return float(tot_err_rate)
|
||||
|
||||
|
||||
class LossRecord(collections.defaultdict):
|
||||
class MetricsTracker(collections.defaultdict):
|
||||
def __init__(self):
|
||||
# Passing the type 'int' to the base-class constructor
|
||||
# makes undefined items default to int() which is zero.
|
||||
super(LossRecord, self).__init__(int)
|
||||
# This class will play a role as metrics tracker.
|
||||
# It can record many metrics, including but not limited to loss.
|
||||
super(MetricsTracker, self).__init__(int)
|
||||
|
||||
def __add__(self, other: "LossRecord") -> "LossRecord":
|
||||
ans = LossRecord()
|
||||
def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
|
||||
ans = MetricsTracker()
|
||||
for k, v in self.items():
|
||||
ans[k] = v
|
||||
for k, v in other.items():
|
||||
ans[k] = ans[k] + v
|
||||
return ans
|
||||
|
||||
def __mul__(self, alpha: float) -> "LossRecord":
|
||||
ans = LossRecord()
|
||||
def __mul__(self, alpha: float) -> "MetricsTracker":
|
||||
ans = MetricsTracker()
|
||||
for k, v in self.items():
|
||||
ans[k] = v * alpha
|
||||
return ans
|
||||
|
Loading…
x
Reference in New Issue
Block a user