do some changes

luomingshuang 2022-03-15 20:31:53 +08:00
parent fb5d677c7f
commit 16dda9672f


@@ -18,7 +18,7 @@
 import random
-from typing import List, Tuple
+from typing import List, Optional, Tuple

 import torch
 from torch import Tensor, nn
@@ -29,18 +29,14 @@ class TensorDiagnosticOptions(object):
     Args:
       memory_limit:
         The maximum number of bytes per tensor
         (limits how many copies of the tensor we cache).
       max_eig_dim:
         The maximum dimension for which we print out eigenvalues
         (limited for speed reasons).
     """

-    def __init__(
-        self,
-        memory_limit: int = (2 ** 20),
-        max_eig_dim: int = 512
-    ):
+    def __init__(self, memory_limit: int = (2 ** 20), max_eig_dim: int = 512):
         self.memory_limit = memory_limit
         self.max_eig_dim = max_eig_dim
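Note: the collapsed `__init__` above keeps the same defaults as before. For reference, a minimal sketch of constructing the options with those defaults spelled out (the values are the defaults shown in this hunk, not new recommendations):

    # Cache at most 2**20 bytes per tensor; only compute eigenvalue
    # diagnostics for dimensions of size <= 512.
    opts = TensorDiagnosticOptions(memory_limit=2 ** 20, max_eig_dim=512)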
@@ -49,24 +45,29 @@ class TensorDiagnosticOptions(object):
 def get_tensor_stats(
-    x: Tensor, dim: int, stats_type: str
+    x: Tensor,
+    dim: int,
+    stats_type: str,
 ) -> Tuple[Tensor, int]:
     """
     Returns the specified transformation of the Tensor (either x or x.abs()
     or (x > 0), summed over all but the index `dim`.

     Args:
-      x: Tensor, tensor to be analyzed
-      dim: dimension with 0 <= dim < x.ndim
+      x:
+        Tensor, tensor to be analyzed
+      dim:
+        Dimension with 0 <= dim < x.ndim
       stats_type:
+        The stats_type includes several types:
         "abs" -> take abs() before summing
         "positive" -> take (x > 0) before summing
         "rms" -> square before summing, we'll take sqrt later
-        "value -> just sum x itself
-    Returns (stats, count)
-     where stats is a Tensor of shape (x.shape[dim],), and the count
-     is an integer saying how many items were counted in each element
-     of stats.
+        "value" -> just sum x itself
+    Returns:
+      stats: a Tensor of shape (x.shape[dim],).
+      count: an integer saying how many items were counted in each element
+        of stats.
     """
     count = x.numel() // x.shape[dim]
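To make the docstring concrete, here is a small sketch of what the "abs" path computes, following the behavior described above (sum the transformed tensor over every axis except `dim`):

    import torch

    x = torch.randn(3, 4)
    dim = 1
    # "abs" stats: sum |x| over all axes except `dim`.
    stats = x.abs().sum(dim=0)         # shape (4,), i.e. (x.shape[dim],)
    # count: how many elements contributed to each entry of `stats`.
    count = x.numel() // x.shape[dim]  # 12 // 4 == 3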
@@ -86,7 +87,7 @@ def get_tensor_stats(
     else:
         assert stats_type == "value"
-    sum_dims = [ d for d in range(x.ndim) if d != dim ]
+    sum_dims = [d for d in range(x.ndim) if d != dim]
     if len(sum_dims) > 0:
         x = torch.sum(x, dim=sum_dims)
     x = x.flatten()
@@ -102,46 +103,49 @@ def get_diagnostics_for_dim(
 ) -> str:
     """
     This function gets diagnostics for a dimension of a module.

     Args:
-      dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
-      options: options object
-      sizes_same: true if all the tensor sizes are the same on this dimension
-      stats_type: either "abs" or "positive" or "eigs" or "value",
-          imdictates the type of stats
-          we accumulate, abs is mean absolute value, "positive"
-          is proportion of positive to nonnegative values, "eigs"
-          is eigenvalues after doing outer product on this dim, sum
-          over all other dimes.
+      dim:
+        the dimension to analyze, with 0 <= dim < tensors[0].ndim
+      options:
+        options object
+      sizes_same:
+        True if all the tensor sizes are the same on this dimension
+      stats_type: either "abs" or "positive" or "eigs" or "value",
+        dictates the type of stats we accumulate, "abs" is mean absolute
+        value, "positive" is proportion of positive to nonnegative values,
+        "eigs" is eigenvalues after doing outer product on this dim, summed
+        over all other dims.
     Returns:
       Diagnostic as a string, either percentiles or the actual values,
       see the code. Will return the empty string if the diagnostics did
       not make sense to print out for this dimension, e.g. dimension
-      mismatch and stats_type == "eigs"
+      mismatch and stats_type == "eigs".
     """
     # stats_and_counts is a list of pair (Tensor, int)
-    stats_and_counts = [ get_tensor_stats(x, dim, stats_type) for x in tensors ]
-    stats = [ x[0] for x in stats_and_counts ]
-    counts = [ x[1] for x in stats_and_counts ]
+    stats_and_counts = [get_tensor_stats(x, dim, stats_type) for x in tensors]
+    stats = [x[0] for x in stats_and_counts]
+    counts = [x[1] for x in stats_and_counts]

     if stats_type == "eigs":
         try:
             stats = torch.stack(stats).sum(dim=0)
-        except:
-            return ''
+        except:  # noqa
+            return ""
         count = sum(counts)
         stats = stats / count
         stats, _ = torch.symeig(stats)
         stats = stats.abs().sqrt()
         # sqrt so it reflects data magnitude, like stddev- not variance
     elif sizes_same:
         stats = torch.stack(stats).sum(dim=0)
         count = sum(counts)
         stats = stats / count
     else:
-        stats = [ x[0] / x[1] for x in stats_and_counts ]
+        stats = [x[0] / x[1] for x in stats_and_counts]
         stats = torch.cat(stats, dim=0)
-        if stats_type == 'rms':
+        if stats_type == "rms":
             stats = stats.sqrt()

     # if `summarize` we print percentiles of the stats; else,
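One caution about the "eigs" branch kept by this hunk: `torch.symeig` was deprecated in later PyTorch releases in favor of `torch.linalg.eigvalsh`. On a newer PyTorch, an equivalent form (a sketch, not part of this commit) would be:

    # eigvalsh returns the eigenvalues of a symmetric matrix, which is
    # all this code uses; sqrt so the result reflects data magnitude.
    stats = torch.linalg.eigvalsh(stats)
    stats = stats.abs().sqrt()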
@@ -156,13 +160,13 @@ def get_diagnostics_for_dim(
         for i in range(num_percentiles + 1):
             index = (i * (size - 1)) // num_percentiles
             percentiles.append(stats[index].item())
-        percentiles = [ '%.2g' % x for x in percentiles ]
-        percentiles = ' '.join(percentiles)
-        ans = f'percentiles: [{percentiles}]'
+        percentiles = ["%.2g" % x for x in percentiles]
+        percentiles = " ".join(percentiles)
+        ans = f"percentiles: [{percentiles}]"
     else:
         ans = stats.tolist()
-        ans = [ '%.2g' % x for x in ans ]
-        ans = '[' + ' '.join(ans) + ']'
+        ans = ["%.2g" % x for x in ans]
+        ans = "[" + " ".join(ans) + "]"
     if stats_type == "value":
         # This norm is useful because it is strictly less than the largest
         # sqrt(eigenvalue) of the variance, which we print out, and shows,
@@ -171,11 +175,11 @@ def get_diagnostics_for_dim(
         norm = (stats ** 2).sum().sqrt().item()
         mean = stats.mean().item()
         rms = (stats ** 2).mean().sqrt().item()
-        ans += f', norm={norm:.2g}, mean={mean:.2g}, rms={rms:.2g}'
+        ans += f", norm={norm:.2g}, mean={mean:.2g}, rms={rms:.2g}"
     else:
         mean = stats.mean().item()
         rms = (stats ** 2).mean().sqrt().item()
-        ans += f', mean={mean:.2g}, rms={rms:.2g}'
+        ans += f", mean={mean:.2g}, rms={rms:.2g}"
     return ans
@@ -201,15 +205,15 @@ def print_diagnostics_for_dim(
         if tensors[0].shape[dim] <= options.max_eig_dim:
             stats_types.append("eigs")
     else:
-        stats_types = [ "value", "abs" ]
+        stats_types = ["value", "abs"]

     for stats_type in stats_types:
-        sizes = [ x.shape[dim] for x in tensors ]
-        sizes_same = all([ x == sizes[0] for x in sizes ])
-        s = get_diagnostics_for_dim(dim, tensors,
-                                    options, sizes_same,
-                                    stats_type)
-        if s == '':
+        sizes = [x.shape[dim] for x in tensors]
+        sizes_same = all([x == sizes[0] for x in sizes])
+        s = get_diagnostics_for_dim(
+            dim, tensors, options, sizes_same, stats_type
+        )
+        if s == "":
             continue

         min_size = min(sizes)
@@ -279,16 +283,13 @@ class TensorDiagnostic(object):
         try:
             device = torch.device("cuda")
-            torch.ones(1, 1, device)
-        except:
+        except:  # noqa
             device = torch.device("cpu")

         ndim = self.saved_tensors[0].ndim
         tensors = [x.to(device) for x in self.saved_tensors]
         for dim in range(ndim):
-            print_diagnostics_for_dim(
-                self.name, dim, tensors, self.opts
-            )
+            print_diagnostics_for_dim(self.name, dim, tensors, self.opts)


 class ModelDiagnostic(object):
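A note on the device-selection change above: constructing `torch.device("cuda")` does not itself raise when no GPU is present, so with the probe removed the `except` branch is unlikely to ever fire. A more explicit selection (a sketch, not part of this commit) would be:

    # Use CUDA only when it is actually available; otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")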
@@ -299,11 +300,14 @@ class ModelDiagnostic(object):
         Options object.
     """

-    def __init__(self, opts: TensorDiagnosticOptions = TensorDiagnosticOptions()):
+    def __init__(self, opts: Optional[TensorDiagnosticOptions] = None):
         # In this dictionary, the keys are tensors names and the values
         # are corresponding TensorDiagnostic objects.
+        if opts is None:
+            self.opts = TensorDiagnosticOptions()
+        else:
+            self.opts = opts
         self.diagnostics = dict()
-        self.opts = opts

     def __getitem__(self, name: str):
         if name not in self.diagnostics:
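The switch to `Optional[TensorDiagnosticOptions] = None` avoids a classic Python pitfall: a default argument is evaluated once, at function definition time, so the old signature made every `ModelDiagnostic` that relied on the default share a single `TensorDiagnosticOptions` instance. A minimal illustration of the general pitfall (hypothetical function, not from this codebase):

    def bad(x, acc=[]):  # one list, created once at definition time
        acc.append(x)
        return acc

    bad(1)  # [1]
    bad(2)  # [1, 2] -- state leaks across calls through the shared default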
@@ -380,7 +384,7 @@ def attach_diagnostics(
 def _test_tensor_diagnostic():
-    opts = TensorDiagnosticOptions(2**20, 512)
+    opts = TensorDiagnosticOptions(2 ** 20, 512)
     diagnostic = TensorDiagnostic(opts, "foo")
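For reference, a sketch of how the test above plausibly continues. The `accumulate()` and `print_diagnostics()` calls are assumptions based on the class shown earlier in this diff; they do not appear in these hunks:

    import torch

    opts = TensorDiagnosticOptions(2 ** 20, 512)
    diagnostic = TensorDiagnostic(opts, "foo")
    for _ in range(10):
        # accumulate() is assumed to record a copy of the tensor for analysis.
        diagnostic.accumulate(torch.randn(50, 100))
    diagnostic.print_diagnostics()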