Add the "rms-sort" diagnostics (#1851)

This commit is contained in:
Han Zhu 2024-12-30 15:27:05 +08:00 committed by GitHub
parent ad966fb81d
commit 57e9f2a8db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -63,12 +63,22 @@ def get_tensor_stats(
"rms" -> square before summing, we'll take sqrt later
"value" -> just sum x itself
"max", "min" -> take the maximum or minimum [over all other dims but dim] instead of summing
"rms-sort" -> this is a bit different than the others, it's based on computing the
rms over the specified dim and returning percentiles of the result (11 of them).
Returns:
stats: a Tensor of shape (x.shape[dim],).
count: an integer saying how many items were counted in each element
of stats.
"""
if stats_type == "rms-sort":
rms = (x**2).mean(dim=dim).sqrt()
rms = rms.flatten()
rms = rms.sort()[0]
rms = rms[(torch.arange(11) * rms.numel() // 10).clamp(max=rms.numel() - 1)]
count = 1.0
return rms, count
count = x.numel() // x.shape[dim]
if stats_type == "eigs":
@ -164,7 +174,17 @@ class TensorDiagnostic(object):
for dim in range(ndim):
this_dim_stats = self.stats[dim]
if ndim > 1:
stats_types = ["abs", "max", "min", "positive", "value", "rms"]
# rms-sort is different from the others, it's based on summing over just this
# dim, then sorting and returning the percentiles.
stats_types = [
"abs",
"max",
"min",
"positive",
"value",
"rms",
"rms-sort",
]
if x.shape[dim] <= self.opts.max_eig_dim:
stats_types.append("eigs")
else: