mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 10:44:19 +00:00
Change the name "LossRecord" to "MetricsTracker"
This commit is contained in:
parent
9a1d76f7c3
commit
ff44415313
@ -46,7 +46,7 @@ from icefall.dist import cleanup_dist, setup_dist
|
|||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
LossRecord,
|
MetricsTracker,
|
||||||
encode_supervisions,
|
encode_supervisions,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
str2bool,
|
str2bool,
|
||||||
@ -291,7 +291,7 @@ def compute_loss(
|
|||||||
batch: dict,
|
batch: dict,
|
||||||
graph_compiler: BpeCtcTrainingGraphCompiler,
|
graph_compiler: BpeCtcTrainingGraphCompiler,
|
||||||
is_training: bool,
|
is_training: bool,
|
||||||
) -> Tuple[Tensor, LossRecord]:
|
) -> Tuple[Tensor, MetricsTracker]:
|
||||||
"""
|
"""
|
||||||
Compute CTC loss given the model and its inputs.
|
Compute CTC loss given the model and its inputs.
|
||||||
|
|
||||||
@ -373,7 +373,7 @@ def compute_loss(
|
|||||||
|
|
||||||
assert loss.requires_grad == is_training
|
assert loss.requires_grad == is_training
|
||||||
|
|
||||||
info = LossRecord()
|
info = MetricsTracker()
|
||||||
info["frames"] = supervision_segments[:, 2].sum().item()
|
info["frames"] = supervision_segments[:, 2].sum().item()
|
||||||
info["ctc_loss"] = ctc_loss.detach().cpu().item()
|
info["ctc_loss"] = ctc_loss.detach().cpu().item()
|
||||||
if params.att_rate != 0.0:
|
if params.att_rate != 0.0:
|
||||||
@ -390,11 +390,11 @@ def compute_validation_loss(
|
|||||||
graph_compiler: BpeCtcTrainingGraphCompiler,
|
graph_compiler: BpeCtcTrainingGraphCompiler,
|
||||||
valid_dl: torch.utils.data.DataLoader,
|
valid_dl: torch.utils.data.DataLoader,
|
||||||
world_size: int = 1,
|
world_size: int = 1,
|
||||||
) -> LossRecord:
|
) -> MetricsTracker:
|
||||||
"""Run the validation process."""
|
"""Run the validation process."""
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
tot_loss = LossRecord()
|
tot_loss = MetricsTracker()
|
||||||
|
|
||||||
for batch_idx, batch in enumerate(valid_dl):
|
for batch_idx, batch in enumerate(valid_dl):
|
||||||
loss, loss_info = compute_loss(
|
loss, loss_info = compute_loss(
|
||||||
@ -454,7 +454,7 @@ def train_one_epoch(
|
|||||||
"""
|
"""
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
tot_loss = LossRecord()
|
tot_loss = MetricsTracker()
|
||||||
|
|
||||||
for batch_idx, batch in enumerate(train_dl):
|
for batch_idx, batch in enumerate(train_dl):
|
||||||
params.batch_idx_train += 1
|
params.batch_idx_train += 1
|
||||||
|
@ -45,7 +45,7 @@ from icefall.graph_compiler import CtcTrainingGraphCompiler
|
|||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
LossRecord,
|
MetricsTracker,
|
||||||
encode_supervisions,
|
encode_supervisions,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
str2bool,
|
str2bool,
|
||||||
@ -270,7 +270,7 @@ def compute_loss(
|
|||||||
batch: dict,
|
batch: dict,
|
||||||
graph_compiler: CtcTrainingGraphCompiler,
|
graph_compiler: CtcTrainingGraphCompiler,
|
||||||
is_training: bool,
|
is_training: bool,
|
||||||
) -> Tuple[Tensor, LossRecord]:
|
) -> Tuple[Tensor, MetricsTracker]:
|
||||||
"""
|
"""
|
||||||
Compute CTC loss given the model and its inputs.
|
Compute CTC loss given the model and its inputs.
|
||||||
|
|
||||||
@ -327,7 +327,7 @@ def compute_loss(
|
|||||||
|
|
||||||
assert loss.requires_grad == is_training
|
assert loss.requires_grad == is_training
|
||||||
|
|
||||||
info = LossRecord()
|
info = MetricsTracker()
|
||||||
info["frames"] = supervision_segments[:, 2].sum().item()
|
info["frames"] = supervision_segments[:, 2].sum().item()
|
||||||
info["loss"] = loss.detach().cpu().item()
|
info["loss"] = loss.detach().cpu().item()
|
||||||
|
|
||||||
@ -340,13 +340,13 @@ def compute_validation_loss(
|
|||||||
graph_compiler: CtcTrainingGraphCompiler,
|
graph_compiler: CtcTrainingGraphCompiler,
|
||||||
valid_dl: torch.utils.data.DataLoader,
|
valid_dl: torch.utils.data.DataLoader,
|
||||||
world_size: int = 1,
|
world_size: int = 1,
|
||||||
) -> LossRecord:
|
) -> MetricsTracker:
|
||||||
"""Run the validation process. The validation loss
|
"""Run the validation process. The validation loss
|
||||||
is saved in `params.valid_loss`.
|
is saved in `params.valid_loss`.
|
||||||
"""
|
"""
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
tot_loss = LossRecord()
|
tot_loss = MetricsTracker()
|
||||||
|
|
||||||
for batch_idx, batch in enumerate(valid_dl):
|
for batch_idx, batch in enumerate(valid_dl):
|
||||||
loss, loss_info = compute_loss(
|
loss, loss_info = compute_loss(
|
||||||
@ -408,7 +408,7 @@ def train_one_epoch(
|
|||||||
"""
|
"""
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
tot_loss = LossRecord()
|
tot_loss = MetricsTracker()
|
||||||
|
|
||||||
for batch_idx, batch in enumerate(train_dl):
|
for batch_idx, batch in enumerate(train_dl):
|
||||||
params.batch_idx_train += 1
|
params.batch_idx_train += 1
|
||||||
|
@ -24,7 +24,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
|||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import AttributeDict, LossRecord, setup_logger, str2bool
|
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -248,7 +248,7 @@ def compute_loss(
|
|||||||
batch: dict,
|
batch: dict,
|
||||||
graph_compiler: CtcTrainingGraphCompiler,
|
graph_compiler: CtcTrainingGraphCompiler,
|
||||||
is_training: bool,
|
is_training: bool,
|
||||||
) -> Tuple[Tensor, LossRecord]:
|
) -> Tuple[Tensor, MetricsTracker]:
|
||||||
"""
|
"""
|
||||||
Compute CTC loss given the model and its inputs.
|
Compute CTC loss given the model and its inputs.
|
||||||
|
|
||||||
@ -308,7 +308,7 @@ def compute_loss(
|
|||||||
|
|
||||||
assert loss.requires_grad == is_training
|
assert loss.requires_grad == is_training
|
||||||
|
|
||||||
info = LossRecord()
|
info = MetricsTracker()
|
||||||
info["frames"] = supervision_segments[:, 2].sum().item()
|
info["frames"] = supervision_segments[:, 2].sum().item()
|
||||||
info["loss"] = loss.detach().cpu().item()
|
info["loss"] = loss.detach().cpu().item()
|
||||||
|
|
||||||
@ -321,13 +321,13 @@ def compute_validation_loss(
|
|||||||
graph_compiler: CtcTrainingGraphCompiler,
|
graph_compiler: CtcTrainingGraphCompiler,
|
||||||
valid_dl: torch.utils.data.DataLoader,
|
valid_dl: torch.utils.data.DataLoader,
|
||||||
world_size: int = 1,
|
world_size: int = 1,
|
||||||
) -> LossRecord:
|
) -> MetricsTracker:
|
||||||
"""Run the validation process. The validation loss
|
"""Run the validation process. The validation loss
|
||||||
is saved in `params.valid_loss`.
|
is saved in `params.valid_loss`.
|
||||||
"""
|
"""
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
tot_loss = LossRecord()
|
tot_loss = MetricsTracker()
|
||||||
|
|
||||||
for batch_idx, batch in enumerate(valid_dl):
|
for batch_idx, batch in enumerate(valid_dl):
|
||||||
loss, loss_info = compute_loss(
|
loss, loss_info = compute_loss(
|
||||||
@ -389,7 +389,7 @@ def train_one_epoch(
|
|||||||
"""
|
"""
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
tot_loss = LossRecord()
|
tot_loss = MetricsTracker()
|
||||||
|
|
||||||
for batch_idx, batch in enumerate(train_dl):
|
for batch_idx, batch in enumerate(train_dl):
|
||||||
params.batch_idx_train += 1
|
params.batch_idx_train += 1
|
||||||
|
@ -424,22 +424,24 @@ def write_error_stats(
|
|||||||
return float(tot_err_rate)
|
return float(tot_err_rate)
|
||||||
|
|
||||||
|
|
||||||
class LossRecord(collections.defaultdict):
|
class MetricsTracker(collections.defaultdict):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# Passing the type 'int' to the base-class constructor
|
# Passing the type 'int' to the base-class constructor
|
||||||
# makes undefined items default to int() which is zero.
|
# makes undefined items default to int() which is zero.
|
||||||
super(LossRecord, self).__init__(int)
|
# This class will play a role as metrics tracker.
|
||||||
|
# It can record many metrics, including but not limited to loss.
|
||||||
|
super(MetricsTracker, self).__init__(int)
|
||||||
|
|
||||||
def __add__(self, other: "LossRecord") -> "LossRecord":
|
def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
|
||||||
ans = LossRecord()
|
ans = MetricsTracker()
|
||||||
for k, v in self.items():
|
for k, v in self.items():
|
||||||
ans[k] = v
|
ans[k] = v
|
||||||
for k, v in other.items():
|
for k, v in other.items():
|
||||||
ans[k] = ans[k] + v
|
ans[k] = ans[k] + v
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
def __mul__(self, alpha: float) -> "LossRecord":
|
def __mul__(self, alpha: float) -> "MetricsTracker":
|
||||||
ans = LossRecord()
|
ans = MetricsTracker()
|
||||||
for k, v in self.items():
|
for k, v in self.items():
|
||||||
ans[k] = v * alpha
|
ans[k] = v * alpha
|
||||||
return ans
|
return ans
|
||||||
|
Loading…
x
Reference in New Issue
Block a user