diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 2367f7171..96c085541 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -211,6 +211,19 @@ class TensorDiagnostic(object): print(f"Warning: the stats of {self.name} is None.") return for dim, this_dim_stats in enumerate(self.stats): + if "rms" in this_dim_stats and "value" in this_dim_stats: + # produce "stddev" stats, which is centered RMS. + rms_stats_list = this_dim_stats["rms"] + value_stats_list = this_dim_stats["value"] + if len(rms_stats_list) == len(value_stats_list): + stddev_stats_list = [] + for r, v in zip(rms_stats_list, value_stats_list): + stddev_stats_list.append( + # r.count and v.count should be the same, but we don't check this. + TensorAndCount(r.tensor - v.tensor * v.tensor / (v.count + 1.0e-20), + r.count)) + this_dim_stats["stddev"] = stddev_stats_list + for stats_type, stats_list in this_dim_stats.items(): # stats_type could be "rms", "value", "abs", "eigs", "positive", "min" or "max". # "stats_list" could be a list of TensorAndCount (one list per distinct tensor @@ -244,7 +257,7 @@ class TensorDiagnostic(object): stats = eigs.abs().sqrt() # sqrt so it reflects data magnitude, like stddev- not variance - if stats_type == "rms": + if stats_type in [ "rms", "stddev" ]: # we stored the square; after aggregation we need to take sqrt. stats = stats.sqrt() @@ -269,7 +282,7 @@ class TensorDiagnostic(object): ans = stats.tolist() ans = ["%.2g" % x for x in ans] ans = "[" + " ".join(ans) + "]" - if stats_type in [ "value", "rms", "eigs" ]: + if stats_type in [ "value", "rms", "stddev", "eigs" ]: # This norm is useful because it is strictly less than the largest # sqrt(eigenvalue) of the variance, which we print out, and shows, # speaking in an approximate way, how much of that largest eigenvalue