mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Add the "rms-sort" diagnostics (#1851)
This commit is contained in:
parent
ad966fb81d
commit
57e9f2a8db
@ -63,12 +63,22 @@ def get_tensor_stats(
|
||||
"rms" -> square before summing, we'll take sqrt later
|
||||
"value" -> just sum x itself
|
||||
"max", "min" -> take the maximum or minimum [over all other dims but dim] instead of summing
|
||||
"rms-sort" -> this is a bit different than the others, it's based on computing the
|
||||
rms over the specified dim and returning percentiles of the result (11 of them).
|
||||
Returns:
|
||||
stats: a Tensor of shape (x.shape[dim],).
|
||||
count: an integer saying how many items were counted in each element
|
||||
of stats.
|
||||
"""
|
||||
|
||||
if stats_type == "rms-sort":
|
||||
rms = (x**2).mean(dim=dim).sqrt()
|
||||
rms = rms.flatten()
|
||||
rms = rms.sort()[0]
|
||||
rms = rms[(torch.arange(11) * rms.numel() // 10).clamp(max=rms.numel() - 1)]
|
||||
count = 1.0
|
||||
return rms, count
|
||||
|
||||
count = x.numel() // x.shape[dim]
|
||||
|
||||
if stats_type == "eigs":
|
||||
@ -164,7 +174,17 @@ class TensorDiagnostic(object):
|
||||
for dim in range(ndim):
|
||||
this_dim_stats = self.stats[dim]
|
||||
if ndim > 1:
|
||||
stats_types = ["abs", "max", "min", "positive", "value", "rms"]
|
||||
# rms-sort is different from the others, it's based on summing over just this
|
||||
# dim, then sorting and returning the percentiles.
|
||||
stats_types = [
|
||||
"abs",
|
||||
"max",
|
||||
"min",
|
||||
"positive",
|
||||
"value",
|
||||
"rms",
|
||||
"rms-sort",
|
||||
]
|
||||
if x.shape[dim] <= self.opts.max_eig_dim:
|
||||
stats_types.append("eigs")
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user