mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Add the "rms-sort" diagnostics (#1851)
This commit is contained in:
parent
ad966fb81d
commit
57e9f2a8db
@ -63,12 +63,22 @@ def get_tensor_stats(
|
|||||||
"rms" -> square before summing, we'll take sqrt later
|
"rms" -> square before summing, we'll take sqrt later
|
||||||
"value" -> just sum x itself
|
"value" -> just sum x itself
|
||||||
"max", "min" -> take the maximum or minimum [over all other dims but dim] instead of summing
|
"max", "min" -> take the maximum or minimum [over all other dims but dim] instead of summing
|
||||||
|
"rms-sort" -> this is a bit different than the others, it's based on computing the
|
||||||
|
rms over the specified dim and returning percentiles of the result (11 of them).
|
||||||
Returns:
|
Returns:
|
||||||
stats: a Tensor of shape (x.shape[dim],).
|
stats: a Tensor of shape (x.shape[dim],).
|
||||||
count: an integer saying how many items were counted in each element
|
count: an integer saying how many items were counted in each element
|
||||||
of stats.
|
of stats.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if stats_type == "rms-sort":
|
||||||
|
rms = (x**2).mean(dim=dim).sqrt()
|
||||||
|
rms = rms.flatten()
|
||||||
|
rms = rms.sort()[0]
|
||||||
|
rms = rms[(torch.arange(11) * rms.numel() // 10).clamp(max=rms.numel() - 1)]
|
||||||
|
count = 1.0
|
||||||
|
return rms, count
|
||||||
|
|
||||||
count = x.numel() // x.shape[dim]
|
count = x.numel() // x.shape[dim]
|
||||||
|
|
||||||
if stats_type == "eigs":
|
if stats_type == "eigs":
|
||||||
@ -164,7 +174,17 @@ class TensorDiagnostic(object):
|
|||||||
for dim in range(ndim):
|
for dim in range(ndim):
|
||||||
this_dim_stats = self.stats[dim]
|
this_dim_stats = self.stats[dim]
|
||||||
if ndim > 1:
|
if ndim > 1:
|
||||||
stats_types = ["abs", "max", "min", "positive", "value", "rms"]
|
# rms-sort is different from the others, it's based on summing over just this
|
||||||
|
# dim, then sorting and returning the percentiles.
|
||||||
|
stats_types = [
|
||||||
|
"abs",
|
||||||
|
"max",
|
||||||
|
"min",
|
||||||
|
"positive",
|
||||||
|
"value",
|
||||||
|
"rms",
|
||||||
|
"rms-sort",
|
||||||
|
]
|
||||||
if x.shape[dim] <= self.opts.max_eig_dim:
|
if x.shape[dim] <= self.opts.max_eig_dim:
|
||||||
stats_types.append("eigs")
|
stats_types.append("eigs")
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user