diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 48c0e683d..66123e718 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -603,6 +603,14 @@ def compute_loss( (feature_lens // params.subsampling_factor).sum().item() ) + info["utterances"] = feature.size(0) + # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa + info["utt_duration"] = feature_lens.sum().item() + # padding proportion of each utterance + info["utt_pad_proportion"] = ( + ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() + ) + # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() info["simple_loss"] = simple_loss.detach().cpu().item() diff --git a/icefall/utils.py b/icefall/utils.py index b38574f0c..b495d6b5a 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -521,13 +521,20 @@ class MetricsTracker(collections.defaultdict): return ans def __str__(self) -> str: - ans = "" + ans_frames = "" + ans_utterances = "" for k, v in self.norm_items(): norm_value = "%.4g" % v - ans += str(k) + "=" + str(norm_value) + ", " + if "utt_" not in k: + ans_frames += str(k) + "=" + str(norm_value) + ", " + else: + ans_utterances += str(k) + "=" + str(norm_value) + ", " frames = "%.2f" % self["frames"] - ans += "over " + str(frames) + " frames." - return ans + ans_frames += "over " + str(frames) + " frames; " + utterances = "%.2f" % self["utterances"] + ans_utterances += "over " + str(utterances) + " utterances." + + return ans_frames + ans_utterances def norm_items(self) -> List[Tuple[str, float]]: """ @@ -535,11 +542,17 @@ class MetricsTracker(collections.defaultdict): [('ctc_loss', 0.1), ('att_loss', 0.07)] """ num_frames = self["frames"] if "frames" in self else 1 + num_utterances = self["utterances"] if "utterances" in self else 1 ans = [] for k, v in self.items(): - if k != "frames": - norm_value = float(v) / num_frames - ans.append((k, norm_value)) + if k == "frames" or k == "utterances": + continue + norm_value = ( + float(v) / num_frames + if "utt_" not in k + else float(v) / num_utterances + ) + ans.append((k, norm_value)) return ans def reduce(self, device):