Add stddev stats in diagnostics.py

This commit is contained in:
Daniel Povey 2022-12-06 19:26:59 +08:00
parent 3f82ee0783
commit 6845da4351

View File

@ -211,6 +211,19 @@ class TensorDiagnostic(object):
print(f"Warning: the stats of {self.name} is None.")
return
for dim, this_dim_stats in enumerate(self.stats):
if "rms" in this_dim_stats and "value" in this_dim_stats:
# produce "stddev" stats, which is centered RMS.
rms_stats_list = this_dim_stats["rms"]
value_stats_list = this_dim_stats["value"]
if len(rms_stats_list) == len(value_stats_list):
stddev_stats_list = []
for r, v in zip(rms_stats_list, value_stats_list):
stddev_stats_list.append(
# r.count and v.count should be the same, but we don't check this.
TensorAndCount(r.tensor - v.tensor * v.tensor / (v.count + 1.0e-20),
r.count))
this_dim_stats["stddev"] = stddev_stats_list
for stats_type, stats_list in this_dim_stats.items():
# stats_type could be "rms", "value", "abs", "eigs", "positive", "min" or "max".
# "stats_list" could be a list of TensorAndCount (one list per distinct tensor
@ -244,7 +257,7 @@ class TensorDiagnostic(object):
stats = eigs.abs().sqrt()
# sqrt so it reflects data magnitude, like stddev- not variance
if stats_type == "rms":
if stats_type in [ "rms", "stddev" ]:
# we stored the square; after aggregation we need to take sqrt.
stats = stats.sqrt()
@ -269,7 +282,7 @@ class TensorDiagnostic(object):
ans = stats.tolist()
ans = ["%.2g" % x for x in ans]
ans = "[" + " ".join(ans) + "]"
if stats_type in [ "value", "rms", "eigs" ]:
if stats_type in [ "value", "rms", "stddev", "eigs" ]:
# This norm is useful because it is strictly less than the largest
# sqrt(eigenvalue) of the variance, which we print out, and shows,
# speaking in an approximate way, how much of that largest eigenvalue