mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
normaliza att_loss with symbols numbers
This commit is contained in:
parent
2fc7535de9
commit
2fc4ce2751
@ -711,6 +711,7 @@ def compute_loss(
|
|||||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||||
|
|
||||||
# Note: We use reduction=sum while computing the loss.
|
# Note: We use reduction=sum while computing the loss.
|
||||||
|
info["symbols"] = sum([len(t) for t in token_ids]) # used to normalize att_loss
|
||||||
info["loss"] = loss.detach().cpu().item()
|
info["loss"] = loss.detach().cpu().item()
|
||||||
info["ctc_loss"] = ctc_loss.detach().cpu().item()
|
info["ctc_loss"] = ctc_loss.detach().cpu().item()
|
||||||
info["att_loss"] = att_loss.detach().cpu().item()
|
info["att_loss"] = att_loss.detach().cpu().item()
|
||||||
|
@ -844,10 +844,13 @@ class MetricsTracker(collections.defaultdict):
|
|||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
ans_frames = ""
|
ans_frames = ""
|
||||||
|
ans_symbols = ""
|
||||||
ans_utterances = ""
|
ans_utterances = ""
|
||||||
for k, v in self.norm_items():
|
for k, v in self.norm_items():
|
||||||
norm_value = "%.4g" % v
|
norm_value = "%.4g" % v
|
||||||
if "utt_" not in k:
|
if k == "att_loss":
|
||||||
|
ans_symbols += str(k) + "=" + str(norm_value) + ", "
|
||||||
|
elif "utt_" not in k:
|
||||||
ans_frames += str(k) + "=" + str(norm_value) + ", "
|
ans_frames += str(k) + "=" + str(norm_value) + ", "
|
||||||
else:
|
else:
|
||||||
ans_utterances += str(k) + "=" + str(norm_value)
|
ans_utterances += str(k) + "=" + str(norm_value)
|
||||||
@ -859,11 +862,13 @@ class MetricsTracker(collections.defaultdict):
|
|||||||
raise ValueError(f"Unexpected key: {k}")
|
raise ValueError(f"Unexpected key: {k}")
|
||||||
frames = "%.2f" % self["frames"]
|
frames = "%.2f" % self["frames"]
|
||||||
ans_frames += "over " + str(frames) + " frames. "
|
ans_frames += "over " + str(frames) + " frames. "
|
||||||
|
symbols = "%.2f" % self["symbols"]
|
||||||
|
ans_symbols += "over " + str(symbols) + " symbols. "
|
||||||
if ans_utterances != "":
|
if ans_utterances != "":
|
||||||
utterances = "%.2f" % self["utterances"]
|
utterances = "%.2f" % self["utterances"]
|
||||||
ans_utterances += "over " + str(utterances) + " utterances."
|
ans_utterances += "over " + str(utterances) + " utterances."
|
||||||
|
|
||||||
return ans_frames + ans_utterances
|
return ans_frames + ans_symbols + ans_utterances
|
||||||
|
|
||||||
def norm_items(self) -> List[Tuple[str, float]]:
|
def norm_items(self) -> List[Tuple[str, float]]:
|
||||||
"""
|
"""
|
||||||
@ -871,14 +876,18 @@ class MetricsTracker(collections.defaultdict):
|
|||||||
[('ctc_loss', 0.1), ('att_loss', 0.07)]
|
[('ctc_loss', 0.1), ('att_loss', 0.07)]
|
||||||
"""
|
"""
|
||||||
num_frames = self["frames"] if "frames" in self else 1
|
num_frames = self["frames"] if "frames" in self else 1
|
||||||
|
num_symbols = self["symbols"] if "symbols" in self else 1
|
||||||
num_utterances = self["utterances"] if "utterances" in self else 1
|
num_utterances = self["utterances"] if "utterances" in self else 1
|
||||||
ans = []
|
ans = []
|
||||||
for k, v in self.items():
|
for k, v in self.items():
|
||||||
if k == "frames" or k == "utterances":
|
if k == "frames" or k == "symbols" or k == "utterances":
|
||||||
continue
|
continue
|
||||||
norm_value = (
|
if k == "att_loss":
|
||||||
float(v) / num_frames if "utt_" not in k else float(v) / num_utterances
|
norm_value = float(v) / num_symbols
|
||||||
)
|
elif "utt_" in k:
|
||||||
|
norm_value = float(v) / num_utterances
|
||||||
|
else:
|
||||||
|
norm_value = float(v) / num_frames
|
||||||
ans.append((k, norm_value))
|
ans.append((k, norm_value))
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user