Draft of new diagnostics for activations

This commit is contained in:
Daniel Povey 2022-11-30 15:57:24 +08:00
parent c75c2dc91d
commit b7cad258bb


@@ -105,7 +105,7 @@ class TensorAndCount:
class TensorDiagnostic(object):
"""This class is not directly used by the user, it is responsible for
collecting diagnostics for a single parameter tensor of a torch.nn.Module.
collecting diagnostics for a module or parameter tensor of a torch.nn.Module.
Args:
opts:
@@ -120,7 +120,14 @@ class TensorDiagnostic(object):
self.name = name
self.class_name = None # will assign in accumulate()
self.stats = None # we'll later assign a list to this data member. It's a list of dict.
self.stats = None # we'll later assign a list to self.stats.
# It's a list of dicts, indexed by dim (i.e. by the
# axis of the tensor). The dicts, in turn, are
# indexed by `stats-type`, which is a string in
# ["abs", "max", "min", "positive", "value", "rms"].
# scalar_stats contains some analysis of the activations and gradients.
self.scalar_stats = None
# the keys into self.stats[dim] are strings, which can be
# "abs", "max", "min", "positive", "value", "rms".
@@ -288,6 +295,168 @@ class TensorDiagnostic(object):
)
class ScalarDiagnostic(object):
"""This class is not directly used by the user, it is responsible for
collecting diagnostics for a single module (subclass of torch.nn.Module) that
represents some kind of nonlinearity, e.g. ReLU, sigmoid, etc.
"""
def __init__(self, opts: TensorDiagnosticOptions, name: str):
self.opts = opts
self.name = name
self.class_name = None # will assign in accumulate()
self.is_forward_pass = True
self.tick_scale = None
self.saved_inputs = []
self.is_ok = True
self.counts = None
self.sum_grad = None
self.sum_gradsq = None
self.sum_abs_grad = None
def accumulate_input(self, x: Tensor, class_name: Optional[str] = None):
"""
Called in forward pass.
"""
if not self.is_forward_pass:
# in case we did a forward pass without a backward pass, for some reason.
self.saved_inputs = []
self.is_forward_pass = True
if class_name is not None:
self.class_name = class_name
if not self.is_ok:
return
limit = 10
if len(self.saved_inputs) > limit:
print(f"ERROR: forward pass called for this module over {limit} times with no backward pass. "
f" Will not accumulate scalar stats.")
self.is_ok = False
return
self.saved_inputs.append(x)
def accumulate_output_grad(self, grad: Tensor):
if not self.is_ok:
return
if self.is_forward_pass:
self.is_forward_pass = False
last_shape = 'n/a' if len(self.saved_inputs) == 0 else self.saved_inputs[-1].shape
if len(self.saved_inputs) == 0 or grad.shape != last_shape:
print(f"ERROR: shape mismatch or no forward activation present when backward "
f"pass called: grad shape ={tuple(grad.shape)}, num-saved-inputs={len(self.saved_inputs)}"
f", shape-of-last-saved-input={last_shape}")
self.is_ok = False
return
x = self.saved_inputs.pop()
self.process_input_and_grad(x, grad)
def process_input_and_grad(self, x: Tensor, grad: Tensor):
assert x.shape == grad.shape
x = x.flatten()
grad = grad.flatten()
num_ticks_per_side = 256
if self.tick_scale is None:
x_abs_sorted = x.abs().sort()[0]
# take the 98th percentile as the largest value we count separately.
index = int(x.numel() * 0.98)
self.tick_scale = float(x_abs_sorted[index] / num_ticks_per_side)
# integerize x into ticks in the range [-num_ticks_per_side, num_ticks_per_side - 1], in units of tick_scale.
self.counts = torch.zeros(2 * num_ticks_per_side, dtype=torch.long, device=x.device)
self.sum_grad = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device)
# sum_gradsq is for getting error bars.
self.sum_gradsq = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device)
self.sum_abs_grad = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device)
print("tick scale:", self.tick_scale)
# conversion to long truncates toward zero.
x = (x / self.tick_scale).to(torch.long)
x = x.clamp_(min=-num_ticks_per_side, max=num_ticks_per_side - 1)
x = x + num_ticks_per_side
print("x indexes: ", x)
self.counts.index_add_(dim=0, index=x, source=torch.ones_like(x))
self.sum_grad.index_add_(dim=0, index=x, source=grad.to(torch.double))
self.sum_gradsq.index_add_(dim=0, index=x, source=(grad*grad).to(torch.double))
self.sum_abs_grad.index_add_(dim=0, index=x, source=grad.abs().to(torch.double))
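To make the bucketing above concrete, here is a tiny self-contained sketch (invented numbers, not code from this commit) of how activation values are mapped to one of the 2 * num_ticks_per_side counters:

import torch

num_ticks_per_side = 256
tick_scale = 0.01  # hypothetical: 98th-percentile of |x| divided by num_ticks_per_side
x = torch.tensor([-5.0, -0.003, 0.0, 0.004, 0.5, 9.9])
idx = (x / tick_scale).to(torch.long)  # truncates toward zero
idx = idx.clamp_(min=-num_ticks_per_side, max=num_ticks_per_side - 1)
idx = idx + num_ticks_per_side  # shift into [0, 2 * num_ticks_per_side - 1]
print(idx)  # tensor([  0, 256, 256, 256, 306, 511])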
def print_diagnostics(self):
"""Print diagnostics."""
if self.is_ok is False or self.counts is None:
print(f"Warning: no stats accumulated for {self.name}, is_ok={self.is_ok}")
return
counts = self.counts.to('cpu')
sum_grad = self.sum_grad.to(device='cpu', dtype=torch.float32)
sum_gradsq = self.sum_gradsq.to(device='cpu', dtype=torch.float32)
sum_abs_grad = self.sum_abs_grad.to(device='cpu', dtype=torch.float32)
counts_cumsum = counts.cumsum(dim=0)  # use the CPU copy, to match the CPU tensors below
counts_tot = counts_cumsum[-1]
# Subdivide the distribution into `num_bins` intervals for analysis, for greater
# statistical significance. Each bin spans multiple of the original 'tick' intervals.
num_bins = 20
# integer division
counts_per_bin = (counts_tot // num_bins) + 1
bin_indexes = counts_cumsum // counts_per_bin
bin_indexes = bin_indexes.clamp(min=0, max=num_bins - 1).to(torch.long)
bin_counts = torch.zeros(num_bins, dtype=torch.long)
bin_counts.index_add_(dim=0, index=bin_indexes, source=counts)
bin_grad = torch.zeros(num_bins)
bin_grad.index_add_(dim=0, index=bin_indexes, source=sum_grad)
bin_gradsq = torch.zeros(num_bins)
bin_gradsq.index_add_(dim=0, index=bin_indexes, source=sum_gradsq)
bin_abs_grad = torch.zeros(num_bins)
bin_abs_grad.index_add_(dim=0, index=bin_indexes, source=sum_abs_grad)
avg_grad = (bin_grad / bin_counts)
avg_grad_stddev = (bin_gradsq / bin_counts).sqrt()
bin_boundary_counts = torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin
bin_tick_indexes = torch.searchsorted(counts_cumsum, bin_boundary_counts)
# boundaries are the "x" values between the bins, e.g. corresponding to the
# locations of percentiles of the distribution.
num_ticks_per_side = counts.numel() // 2
bin_boundaries = (bin_tick_indexes - num_ticks_per_side) * self.tick_scale
bin_grad = bin_grad / (bin_counts + 1)
bin_conf_interval = bin_gradsq.sqrt() / (bin_counts + 1)  # approximately the standard error of bin_grad.
# bin_grad / bin_abs_grad gives us a sense of how significant, in a practical sense,
# the gradients in each bin are.
bin_abs_grad = bin_abs_grad / (bin_counts + 1)
bin_rel_grad = bin_grad / (bin_abs_grad + 1.0e-20)
bin_conf = bin_grad / (bin_conf_interval + 1.0e-20)
def tensor_to_str(x: Tensor):
x = ["%.2g" % f for f in x]
x = "[" + " ".join(x) + "]"
return x
maybe_class_name = f" type={self.class_name}," if self.class_name is not None else ""
print(
f"module={self.name},{maybe_class_name} bin-boundaries={tensor_to_str(bin_boundaries)}, "
f"rel_grad={tensor_to_str(bin_rel_grad)}, grad_conf={tensor_to_str(bin_conf)}"
)
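The equal-count binning above is easier to see on a toy example; this standalone sketch (invented numbers, not code from this commit) shows how the cumulative counts assign ticks to bins so that each bin covers roughly the same number of samples:

import torch

counts = torch.tensor([2, 5, 3, 4, 2, 4], dtype=torch.long)  # toy per-tick histogram
num_bins = 3
counts_cumsum = counts.cumsum(dim=0)  # tensor([ 2,  7, 10, 14, 16, 20])
counts_tot = counts_cumsum[-1]  # 20
counts_per_bin = (counts_tot // num_bins) + 1  # 7
bin_indexes = (counts_cumsum // counts_per_bin).clamp(min=0, max=num_bins - 1)
print(bin_indexes)  # tensor([0, 1, 1, 2, 2, 2])
bin_counts = torch.zeros(num_bins, dtype=torch.long)
bin_counts.index_add_(dim=0, index=bin_indexes, source=counts)
print(bin_counts)  # tensor([ 2,  8, 10])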
class ModelDiagnostic(object):
"""This class stores diagnostics for all tensors in the torch.nn.Module.
@@ -306,9 +475,11 @@ class ModelDiagnostic(object):
self.opts = opts
self.diagnostics = dict()
def __getitem__(self, name: str):
T = ScalarDiagnostic if name.endswith('.scalar') else TensorDiagnostic
if name not in self.diagnostics:
self.diagnostics[name] = TensorDiagnostic(self.opts, name)
self.diagnostics[name] = T(self.opts, name)
return self.diagnostics[name]
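As a hypothetical usage of the key convention above (not code from this commit; `opts` stands in for a TensorDiagnosticOptions object):

diag = ModelDiagnostic(opts)
weight_stats = diag["encoder.layer0.linear"]  # -> TensorDiagnostic
activation_stats = diag["encoder.layer0.relu.scalar"]  # -> ScalarDiagnostic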
def print_diagnostics(self):
@@ -343,7 +514,7 @@ def attach_diagnostics(
# Setting model_diagnostic=ans and n=name below, instead of trying to
# capture the variables, ensures that we use the current values.
# (matters for name, since the variable gets overwritten).
# (this matters for `name`, since the variable gets overwritten).
# These closures don't really capture by value, only by
# "the final value the variable got in the function" :-(
def forward_hook(
@@ -373,8 +544,32 @@ def attach_diagnostics(
_model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
class_name=type(_module).__name__)
module.register_forward_hook(forward_hook)
module.register_backward_hook(backward_hook)
if type(module).__name__ in ["Sigmoid", "Tanh", "ReLU", "TanSwish", "Swish", "DoubleSwish"]:
# For these specific module types, accumulate some additional diagnostics
# that can help us improve the activation function. They require a lot of memory
# (to save the forward activations), so we limit this to a few selected classes.
# Note: this will not work correctly for all model types.
def scalar_forward_hook(
_module, _input, _output, _model_diagnostic=ans, _name=name
):
if isinstance(_input, tuple):
_input, = _input
assert isinstance(_input, Tensor)
_model_diagnostic[f"{_name}.scalar"].accumulate_input(_input,
class_name=type(_module).__name__)
def scalar_backward_hook(
_module, _input, _output, _model_diagnostic=ans, _name=name
):
if isinstance(_output, tuple):
_output, = _output
assert isinstance(_output, Tensor)
_model_diagnostic[f"{_name}.scalar"].accumulate_output_grad(_output)
module.register_forward_hook(scalar_forward_hook)
module.register_backward_hook(scalar_backward_hook)
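For readers less familiar with the hook mechanism these registrations rely on, here is a minimal standalone sketch (independent of this file) showing that a forward hook sees the module's input while a backward hook sees the gradient with respect to its output, which is exactly the pairing accumulate_input() and accumulate_output_grad() depend on:

import torch
import torch.nn as nn

relu = nn.ReLU()

def fwd_hook(module, inputs, output):
    # `inputs` is a tuple; inputs[0] is the tensor fed into the ReLU.
    print("forward input:", inputs[0])

def bwd_hook(module, grad_input, grad_output):
    # grad_output[0] is the gradient of the loss w.r.t. the ReLU's output.
    print("output grad:", grad_output[0])

relu.register_forward_hook(fwd_hook)
relu.register_backward_hook(bwd_hook)  # register_full_backward_hook on newer torch versions

x = torch.randn(4, requires_grad=True)
relu(x).sum().backward()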
for name, parameter in model.named_parameters():
@@ -399,7 +594,7 @@ def _test_tensor_diagnostic():
diagnostic.print_diagnostics()
model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80))
model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 80))
diagnostic = attach_diagnostics(model, opts)
for _ in range(10):