Merge changes to diagnostics

This commit is contained in:
Daniel Povey 2022-03-10 10:31:42 +08:00
commit feb20ca84d

View File

@ -11,24 +11,21 @@ class TensorDiagnosticOptions(object):
Options object for tensor diagnostics: Options object for tensor diagnostics:
Args: Args:
memory_limit: the maximum number of bytes per tensor (limits how many copies memory_limit: the maximum number of bytes we store per tensor (limits how many copies
of the tensor we cache). of the tensor we cache).
max_eig_dim: the maximum dimension for which we print out eigenvalues
(limited for speed reasons).
""" """
def __init__(self, memory_limit: int, def __init__(self,
print_pos_ratio: bool = True): memory_limit: int = (2 ** 20),
max_eig_dim: int = 512):
self.memory_limit = memory_limit self.memory_limit = memory_limit
self.print_pos_ratio = print_pos_ratio self.max_eig_dim = max_eig_dim
def dim_is_summarized(self, size: int): def dim_is_summarized(self, size: int):
return size > 10 and size != 31 return size > 10 and size != 31
def stats_types(self):
if self.print_pos_ratio:
return ["mean-abs", "pos-ratio"]
else:
return ["mean-abs"]
def get_tensor_stats(x: Tensor, dim: int, def get_tensor_stats(x: Tensor, dim: int,
@ -41,25 +38,34 @@ def get_tensor_stats(x: Tensor, dim: int,
x: Tensor, tensor to be analyzed x: Tensor, tensor to be analyzed
dim: dimension with 0 <= dim < x.ndim dim: dimension with 0 <= dim < x.ndim
stats_type: stats_type:
"mean-abs" or "abs-value" -> take abs() before summing "abs" -> take abs() before summing
"pos-ratio" -> take (x > 0) before summing "positive" -> take (x > 0) before summing
"rms" -> square before summing, we'll take sqrt later
"value" -> just sum x itself "value" -> just sum x itself
Returns (stats, count) Returns (stats, count)
where stats is a Tensor of shape (x.shape[dim],), and the count where stats is a Tensor of shape (x.shape[dim],), and the count
is an integer saying how many items were counted in each element is an integer saying how many items were counted in each element
of stats. of stats.
""" """
if stats_type == "mean-abs" or stats_type == "abs-value": count = x.numel() // x.shape[dim]
if stats_type == "eigs":
x = x.transpose(dim, -1)
x = x.reshape(-1, x.shape[-1])
# shape of returned tensor: (s, s) where s is size of dimension `dim` of original x.
return torch.matmul(x.transpose(0, 1), x), count
elif stats_type == "abs":
x = x.abs() x = x.abs()
elif stats_type == "pos-ratio": elif stats_type == "rms":
x = x ** 2
elif stats_type == "positive":
x = (x > 0).to(dtype=torch.float) x = (x > 0).to(dtype=torch.float)
else: else:
assert stats_type == "value" assert stats_type == "value"
orig_numel = x.numel()
sum_dims = [ d for d in range(x.ndim) if d != dim ] sum_dims = [ d for d in range(x.ndim) if d != dim ]
if len(sum_dims) > 0: if len(sum_dims) > 0:
x = torch.sum(x, dim=sum_dims) x = torch.sum(x, dim=sum_dims)
count = orig_numel // x.numel()
x = x.flatten() x = x.flatten()
return x, count return x, count
@ -73,24 +79,42 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
options: options object options: options object
sizes_same: true if all the tensor sizes are the same on this dimension sizes_same: true if all the tensor sizes are the same on this dimension
stats_type: either "mean-abs" or "pos-ratio", dictates the type of stats stats_type: either "abs" or "positive" or "eigs" or "value",
we accumulate, mean-abs is mean absolute value, "pos-ratio" dictates the type of stats
is proportion of positive to nonnegative values. we accumulate, abs is mean absolute value, "positive"
is proportion of positive to nonnegative values, "eigs"
is eigenvalues after doing outer product on this dim, sum
over all other dims.
Returns: Returns:
Diagnostic as a string, either percentiles or the actual values, Diagnostic as a string, either percentiles or the actual values,
see the code. see the code. Will return the empty string if the diagnostics did
not make sense to print out for this dimension, e.g. dimension
mismatch and stats_type == "eigs"
""" """
# stats_and_counts is a list of pair (Tensor, int) # stats_and_counts is a list of pair (Tensor, int)
stats_and_counts = [ get_tensor_stats(x, dim, stats_type) for x in tensors ] stats_and_counts = [ get_tensor_stats(x, dim, stats_type) for x in tensors ]
stats = [ x[0] for x in stats_and_counts ] stats = [ x[0] for x in stats_and_counts ]
counts = [ x[1] for x in stats_and_counts ] counts = [ x[1] for x in stats_and_counts ]
if sizes_same:
if stats_type == "eigs":
try:
stats = torch.stack(stats).sum(dim=0)
except:
return ''
count = sum(counts)
stats = stats / count
stats, _ = torch.symeig(stats)
stats = stats.abs().sqrt() # sqrt so it reflects data magnitude, like stddev- not variance
elif sizes_same:
stats = torch.stack(stats).sum(dim=0) stats = torch.stack(stats).sum(dim=0)
count = sum(counts) count = sum(counts)
stats = stats / count stats = stats / count
else: else:
stats = [ x[0] / x[1] for x in stats_and_counts ] stats = [ x[0] / x[1] for x in stats_and_counts ]
stats = torch.cat(stats, dim=0) stats = torch.cat(stats, dim=0)
if stats_type == 'rms':
stats = stats.sqrt()
# if `summarize` we print percentiles of the stats; else, # if `summarize` we print percentiles of the stats; else,
# we print out individual elements. # we print out individual elements.
summarize = (not sizes_same) or options.dim_is_summarized(stats.numel()) summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
@ -117,9 +141,12 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor], def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor],
options: TensorDiagnosticOptions): options: TensorDiagnosticOptions):
ndim = tensors[0].ndim ndim = tensors[0].ndim
# options.stats_types() should return [ "mean-abs", "pos-ratio" ] in the if ndim > 1:
# normal case. stats_types = ["abs", "positive", "value", "rms"]
stats_types = options.stats_types() if ndim > 1 else [ "value", "abs-value" ] if tensors[0].shape[dim] <= options.max_eig_dim:
stats_types.append("eigs")
else:
stats_types = [ "value", "abs" ]
for stats_type in stats_types: for stats_type in stats_types:
sizes = [ x.shape[dim] for x in tensors ] sizes = [ x.shape[dim] for x in tensors ]
@ -127,11 +154,13 @@ def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor],
s = get_diagnostics_for_dim(dim, tensors, s = get_diagnostics_for_dim(dim, tensors,
options, sizes_same, options, sizes_same,
stats_type) stats_type)
if s == '':
continue
min_size = min(sizes) min_size = min(sizes)
max_size = max(sizes) max_size = max(sizes)
size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}" size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}"
# stats_type will be "mean-abs" or "pos-ratio". # stats_type will be one of "abs", "positive", "value", "rms" or "eigs".
print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}") print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}")
@ -181,15 +210,22 @@ class TensorDiagnostic(object):
# ensure there is at least one dim. # ensure there is at least one dim.
self.saved_tensors = [ x.unsqueeze(0) for x in self.saved_tensors ] self.saved_tensors = [ x.unsqueeze(0) for x in self.saved_tensors ]
try:
device = torch.device('cuda')
torch.ones(1, 1, device)
except:
device = torch.device('cpu')
ndim = self.saved_tensors[0].ndim ndim = self.saved_tensors[0].ndim
tensors = [x.to(device) for x in self.saved_tensors]
for dim in range(ndim): for dim in range(ndim):
print_diagnostics_for_dim(self.name, dim, print_diagnostics_for_dim(self.name, dim,
self.saved_tensors, tensors,
self.opts) self.opts)
class ModelDiagnostic(object): class ModelDiagnostic(object):
def __init__(self, opts: TensorDiagnosticOptions): def __init__(self, opts: TensorDiagnosticOptions = TensorDiagnosticOptions()):
self.diagnostics = dict() self.diagnostics = dict()
self.opts = opts self.opts = opts
@ -252,7 +288,7 @@ def attach_diagnostics(model: nn.Module,
def _test_tensor_diagnostic(): def _test_tensor_diagnostic():
opts = TensorDiagnosticOptions(2**20, True) opts = TensorDiagnosticOptions(2**20, 512)
diagnostic = TensorDiagnostic(opts, "foo") diagnostic = TensorDiagnostic(opts, "foo")