diff --git a/egs/librispeech/ASR/zipformer_ctc_attn/train.py b/egs/librispeech/ASR/zipformer_ctc_attn/train.py index 27a9add0d..06d4a4df0 100755 --- a/egs/librispeech/ASR/zipformer_ctc_attn/train.py +++ b/egs/librispeech/ASR/zipformer_ctc_attn/train.py @@ -711,6 +711,7 @@ def compute_loss( info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. + info["symbols"] = sum([len(t) for t in token_ids]) # used to normalize att_loss info["loss"] = loss.detach().cpu().item() info["ctc_loss"] = ctc_loss.detach().cpu().item() info["att_loss"] = att_loss.detach().cpu().item() diff --git a/icefall/utils.py b/icefall/utils.py index 99e51a2a9..d51760c54 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -844,10 +844,13 @@ class MetricsTracker(collections.defaultdict): def __str__(self) -> str: ans_frames = "" + ans_symbols = "" ans_utterances = "" for k, v in self.norm_items(): norm_value = "%.4g" % v - if "utt_" not in k: + if k == "att_loss": + ans_symbols += str(k) + "=" + str(norm_value) + ", " + elif "utt_" not in k: ans_frames += str(k) + "=" + str(norm_value) + ", " else: ans_utterances += str(k) + "=" + str(norm_value) @@ -859,11 +862,13 @@ class MetricsTracker(collections.defaultdict): raise ValueError(f"Unexpected key: {k}") frames = "%.2f" % self["frames"] ans_frames += "over " + str(frames) + " frames. " + symbols = "%.2f" % self["symbols"] + ans_symbols += "over " + str(symbols) + " symbols. " if ans_utterances != "": utterances = "%.2f" % self["utterances"] ans_utterances += "over " + str(utterances) + " utterances." - return ans_frames + ans_utterances + return ans_frames + ans_symbols + ans_utterances def norm_items(self) -> List[Tuple[str, float]]: """ @@ -871,14 +876,18 @@ class MetricsTracker(collections.defaultdict): [('ctc_loss', 0.1), ('att_loss', 0.07)] """ num_frames = self["frames"] if "frames" in self else 1 + num_symbols = self["symbols"] if "symbols" in self else 1 num_utterances = self["utterances"] if "utterances" in self else 1 ans = [] for k, v in self.items(): - if k == "frames" or k == "utterances": + if k == "frames" or k == "symbols" or k == "utterances": continue - norm_value = ( - float(v) / num_frames if "utt_" not in k else float(v) / num_utterances - ) + if k == "att_loss": + norm_value = float(v) / num_symbols + elif "utt_" in k: + norm_value = float(v) / num_utterances + else: + norm_value = float(v) / num_frames ans.append((k, norm_value)) return ans