do some changes

parent fb5d677c7f
commit 16dda9672f
@@ -18,7 +18,7 @@


 import random
-from typing import List, Tuple
+from typing import List, Optional, Tuple

 import torch
 from torch import Tensor, nn
@@ -29,18 +29,14 @@ class TensorDiagnosticOptions(object):

     Args:
       memory_limit:
         The maximum number of bytes per tensor
         (limits how many copies of the tensor we cache).
       max_eig_dim:
         The maximum dimension for which we print out eigenvalues
         (limited for speed reasons).
     """

-    def __init__(
-        self,
-        memory_limit: int = (2 ** 20),
-        max_eig_dim: int = 512
-    ):
+    def __init__(self, memory_limit: int = (2 ** 20), max_eig_dim: int = 512):
         self.memory_limit = memory_limit
         self.max_eig_dim = max_eig_dim

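As a point of reference for the default `memory_limit`, here is a small sketch (an illustration, not code from this diff) of how a budget of 2 ** 20 bytes relates to tensor size, assuming the cache compares `element_size() * numel()` against `memory_limit`:

```python
import torch

# Hypothetical illustration: with the default memory_limit of 2 ** 20 bytes,
# one cached float32 copy may hold up to 2 ** 18 elements, since float32
# takes 4 bytes per element.
x = torch.zeros(2 ** 18, dtype=torch.float32)
n_bytes = x.numel() * x.element_size()  # 262144 * 4 = 1048576
assert n_bytes == 2 ** 20
```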
@@ -49,24 +45,29 @@ class TensorDiagnosticOptions(object):


 def get_tensor_stats(
-    x: Tensor, dim: int, stats_type: str
+    x: Tensor,
+    dim: int,
+    stats_type: str,
 ) -> Tuple[Tensor, int]:
     """
     Returns the specified transformation of the Tensor (either x or x.abs()
     or (x > 0)), summed over all but the index `dim`.

     Args:
-        x: Tensor, tensor to be analyzed
-        dim: dimension with 0 <= dim < x.ndim
+        x:
+          Tensor, tensor to be analyzed
+        dim:
+          Dimension with 0 <= dim < x.ndim
         stats_type:
-           "abs" -> take abs() before summing
-           "positive" -> take (x > 0) before summing
-           "rms" -> square before summing, we'll take sqrt later
-           "value -> just sum x itself
-    Returns (stats, count)
-     where stats is a Tensor of shape (x.shape[dim],), and the count
-     is an integer saying how many items were counted in each element
-     of stats.
+          The stats_type includes several types:
+            "abs" -> take abs() before summing
+            "positive" -> take (x > 0) before summing
+            "rms" -> square before summing, we'll take sqrt later
+            "value" -> just sum x itself
+    Returns:
+        stats: a Tensor of shape (x.shape[dim],).
+        count: an integer saying how many items were counted in each element
+          of stats.
     """

     count = x.numel() // x.shape[dim]
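To make the four reductions concrete, a small sketch that mirrors the docstring above (the actual branch code lives outside this hunk):

```python
import torch

x = torch.tensor([[1.0, -2.0], [3.0, -4.0]])
dim = 1  # keep dimension 1; reduce over dimension 0

abs_stats = x.abs().sum(dim=0)                    # "abs":      tensor([4., 6.])
pos_stats = (x > 0).to(torch.float32).sum(dim=0)  # "positive": tensor([2., 0.])
rms_stats = (x ** 2).sum(dim=0)                   # "rms":      tensor([10., 20.]); sqrt applied later
val_stats = x.sum(dim=0)                          # "value":    tensor([4., -6.])

count = x.numel() // x.shape[dim]  # 2 items summed into each stats element
```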
@@ -86,7 +87,7 @@ def get_tensor_stats(
     else:
         assert stats_type == "value"

-    sum_dims = [ d for d in range(x.ndim) if d != dim ]
+    sum_dims = [d for d in range(x.ndim) if d != dim]
     if len(sum_dims) > 0:
         x = torch.sum(x, dim=sum_dims)
     x = x.flatten()
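A quick shape check: summing over every dimension except `dim`, then flattening, yields a vector of length `x.shape[dim]`:

```python
import torch

x = torch.randn(2, 3, 4)
dim = 1
sum_dims = [d for d in range(x.ndim) if d != dim]  # [0, 2]
y = torch.sum(x, dim=sum_dims).flatten()
assert y.shape == (x.shape[dim],)  # shape (3,)
```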
@@ -102,46 +103,49 @@ def get_diagnostics_for_dim(
 ) -> str:
     """
     This function gets diagnostics for a dimension of a module.

     Args:
-        dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
-        options: options object
-        sizes_same: true if all the tensor sizes are the same on this dimension
-        stats_type: either "abs" or "positive" or "eigs" or "value",
-              imdictates the type of stats
-              we accumulate, abs is mean absolute value, "positive"
-              is proportion of positive to nonnegative values, "eigs"
-              is eigenvalues after doing outer product on this dim, sum
-              over all other dimes.
+        dim:
+          the dimension to analyze, with 0 <= dim < tensors[0].ndim
+        options:
+          options object
+        sizes_same:
+          True if all the tensor sizes are the same on this dimension
+        stats_type: either "abs" or "positive" or "eigs" or "value",
+          it dictates the type of stats we accumulate, "abs" is mean absolute
+          value, "positive" is proportion of positive to nonnegative values,
+          "eigs" is eigenvalues after doing outer product on this dim, summed
+          over all other dims.

     Returns:
         Diagnostic as a string, either percentiles or the actual values,
         see the code.  Will return the empty string if the diagnostics did
         not make sense to print out for this dimension, e.g. dimension
         mismatch and stats_type == "eigs".
     """

     # stats_and_counts is a list of pairs (Tensor, int)
-    stats_and_counts = [ get_tensor_stats(x, dim, stats_type) for x in tensors ]
-    stats = [ x[0] for x in stats_and_counts ]
-    counts = [ x[1] for x in stats_and_counts ]
+    stats_and_counts = [get_tensor_stats(x, dim, stats_type) for x in tensors]
+    stats = [x[0] for x in stats_and_counts]
+    counts = [x[1] for x in stats_and_counts]

     if stats_type == "eigs":
         try:
             stats = torch.stack(stats).sum(dim=0)
-        except:
-            return ''
+        except:  # noqa
+            return ""
         count = sum(counts)
         stats = stats / count
         stats, _ = torch.symeig(stats)
         stats = stats.abs().sqrt()
         # sqrt so it reflects data magnitude, like stddev, not variance
     elif sizes_same:
         stats = torch.stack(stats).sum(dim=0)
         count = sum(counts)
         stats = stats / count
     else:
-        stats = [ x[0] / x[1] for x in stats_and_counts ]
+        stats = [x[0] / x[1] for x in stats_and_counts]
         stats = torch.cat(stats, dim=0)
-    if stats_type == 'rms':
+    if stats_type == "rms":
         stats = stats.sqrt()

     # if `summarize` we print percentiles of the stats; else,
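A portability note on the "eigs" branch: `torch.symeig` was deprecated in PyTorch 1.9 and later removed. A sketch of the equivalent computation with `torch.linalg.eigvalsh`, using a stand-in symmetric matrix (the accumulation of the outer products happens in `get_tensor_stats`, outside this hunk):

```python
import torch

# Stand-in for the accumulated outer-product ("covariance-like") matrix.
cov = torch.randn(8, 8)
cov = cov @ cov.t()  # symmetric, as symeig/eigvalsh expect

# Old API: stats, _ = torch.symeig(cov)   (deprecated, then removed)
stats = torch.linalg.eigvalsh(cov)  # eigenvalues of a symmetric matrix, ascending
stats = stats.abs().sqrt()  # sqrt so it reflects magnitude, like stddev
```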
@@ -156,13 +160,13 @@ def get_diagnostics_for_dim(
             for i in range(num_percentiles + 1):
                 index = (i * (size - 1)) // num_percentiles
                 percentiles.append(stats[index].item())
-            percentiles = [ '%.2g' % x for x in percentiles ]
-            percentiles = ' '.join(percentiles)
-            ans = f'percentiles: [{percentiles}]'
+            percentiles = ["%.2g" % x for x in percentiles]
+            percentiles = " ".join(percentiles)
+            ans = f"percentiles: [{percentiles}]"
         else:
             ans = stats.tolist()
-            ans = [ '%.2g' % x for x in ans ]
-            ans = '[' + ' '.join(ans) + ']'
+            ans = ["%.2g" % x for x in ans]
+            ans = "[" + " ".join(ans) + "]"
         if stats_type == "value":
             # This norm is useful because it is strictly less than the largest
             # sqrt(eigenvalue) of the variance, which we print out, and shows,
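The percentile loop picks `num_percentiles + 1` evenly spaced order statistics. A standalone sketch, assuming `stats` was sorted in the omitted lines just above this hunk:

```python
import torch

stats = torch.randn(100).sort()[0]  # assumed sorted, as in the omitted context
size = stats.numel()
num_percentiles = 10

percentiles = []
for i in range(num_percentiles + 1):
    index = (i * (size - 1)) // num_percentiles  # 0, 9, 19, ..., 99
    percentiles.append(stats[index].item())
print("percentiles: [" + " ".join("%.2g" % x for x in percentiles) + "]")
```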
@@ -171,11 +175,11 @@ def get_diagnostics_for_dim(
             norm = (stats ** 2).sum().sqrt().item()
             mean = stats.mean().item()
             rms = (stats ** 2).mean().sqrt().item()
-            ans += f', norm={norm:.2g}, mean={mean:.2g}, rms={rms:.2g}'
+            ans += f", norm={norm:.2g}, mean={mean:.2g}, rms={rms:.2g}"
         else:
             mean = stats.mean().item()
             rms = (stats ** 2).mean().sqrt().item()
-            ans += f', mean={mean:.2g}, rms={rms:.2g}'
+            ans += f", mean={mean:.2g}, rms={rms:.2g}"
         return ans

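The three summary figures are tied together: for a stats vector with n elements, norm = sqrt(sum of s_i^2) and rms = sqrt(mean of s_i^2), hence norm = rms * sqrt(n). A quick check:

```python
import torch

stats = torch.randn(16)
norm = (stats ** 2).sum().sqrt()
rms = (stats ** 2).mean().sqrt()
assert torch.allclose(norm, rms * stats.numel() ** 0.5)
```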
@@ -201,15 +205,15 @@ def print_diagnostics_for_dim(
         if tensors[0].shape[dim] <= options.max_eig_dim:
             stats_types.append("eigs")
     else:
-        stats_types = [ "value", "abs" ]
+        stats_types = ["value", "abs"]

     for stats_type in stats_types:
-        sizes = [ x.shape[dim] for x in tensors ]
-        sizes_same = all([ x == sizes[0] for x in sizes ])
-        s = get_diagnostics_for_dim(dim, tensors,
-                                    options, sizes_same,
-                                    stats_type)
-        if s == '':
+        sizes = [x.shape[dim] for x in tensors]
+        sizes_same = all([x == sizes[0] for x in sizes])
+        s = get_diagnostics_for_dim(
+            dim, tensors, options, sizes_same, stats_type
+        )
+        if s == "":
             continue

         min_size = min(sizes)
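For clarity, what `sizes_same` computes, on a hypothetical pair of tensors that agree on dim 0 but not on dim 1 (note `all(...)` also accepts a plain generator, without the inner brackets):

```python
import torch

tensors = [torch.randn(4, 8), torch.randn(4, 16)]
dim = 1
sizes = [x.shape[dim] for x in tensors]         # [8, 16]
sizes_same = all(x == sizes[0] for x in sizes)  # False, so "eigs" diagnostics
                                                # would return the empty string
```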
@@ -279,16 +283,13 @@ class TensorDiagnostic(object):

         try:
             device = torch.device("cuda")
-            torch.ones(1, 1, device)
-        except:
+        except:  # noqa
             device = torch.device("cpu")

         ndim = self.saved_tensors[0].ndim
         tensors = [x.to(device) for x in self.saved_tensors]
         for dim in range(ndim):
-            print_diagnostics_for_dim(
-                self.name, dim, tensors, self.opts
-            )
+            print_diagnostics_for_dim(self.name, dim, tensors, self.opts)


 class ModelDiagnostic(object):
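A caveat on the simplified try/except: `torch.device("cuda")` only constructs a device object and does not raise on CPU-only machines, so the except branch is never taken (the removed `torch.ones` probe appears to have been buggy too, passing `device` positionally where a size is expected). A more explicit sketch:

```python
import torch

# Sketch: test for CUDA explicitly instead of relying on an exception,
# since constructing torch.device("cuda") succeeds even without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```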
@@ -299,11 +300,14 @@ class ModelDiagnostic(object):
         Options object.
     """

-    def __init__(self, opts: TensorDiagnosticOptions = TensorDiagnosticOptions()):
+    def __init__(self, opts: Optional[TensorDiagnosticOptions] = None):
         # In this dictionary, the keys are tensor names and the values
         # are corresponding TensorDiagnostic objects.
+        if opts is None:
+            self.opts = TensorDiagnosticOptions()
+        else:
+            self.opts = opts
         self.diagnostics = dict()
-        self.opts = opts

     def __getitem__(self, name: str):
         if name not in self.diagnostics:
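The signature change here avoids Python's mutable-default pitfall: a default like `TensorDiagnosticOptions()` is evaluated once, at function definition time, and then shared by every `ModelDiagnostic` that relies on the default. A minimal sketch of the difference:

```python
class Options:
    pass


def bad(opts: Options = Options()):  # default built once, at definition time
    return opts


def good(opts: "Options | None" = None):  # the pattern this diff adopts
    return Options() if opts is None else opts


assert bad() is bad()        # every call shares one instance
assert good() is not good()  # every call gets a fresh instance
```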
@@ -380,7 +384,7 @@ def attach_diagnostics(


 def _test_tensor_diagnostic():
-    opts = TensorDiagnosticOptions(2**20, 512)
+    opts = TensorDiagnosticOptions(2 ** 20, 512)

     diagnostic = TensorDiagnostic(opts, "foo")
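For orientation, a usage sketch of the pieces this diff touches; it assumes `TensorDiagnostic` exposes `accumulate()` and `print_diagnostics()`, which are defined elsewhere in this file:

```python
import torch

opts = TensorDiagnosticOptions(2 ** 20, 512)
diagnostic = TensorDiagnostic(opts, "foo")

# Feed a few tensors, then print per-dimension stats.
for _ in range(10):
    diagnostic.accumulate(torch.randn(50, 100) * 10.0)
diagnostic.print_diagnostics()
```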