Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Draft of new diagnostics for activations
This commit is contained in:
parent c75c2dc91d
commit b7cad258bb
@@ -105,7 +105,7 @@ class TensorAndCount:

class TensorDiagnostic(object):
    """This class is not directly used by the user, it is responsible for
    collecting diagnostics for a single parameter tensor of a torch.nn.Module.
    collecting diagnostics for a module or parameter tensor of a torch.nn.Module.

    Args:
      opts:

@@ -120,7 +120,14 @@ class TensorDiagnostic(object):
        self.name = name
        self.class_name = None  # will assign in accumulate()

        self.stats = None  # we'll later assign a list to this data member. It's a list of dict.
        self.stats = None  # we'll later assign a list to self.stats.
        # It's a list of dicts, indexed by dim (i.e. by the
        # axis of the tensor).  The dicts, in turn, are
        # indexed by `stats-type`, which is one of the strings in
        # ["abs", "max", "min", "positive", "value", "rms"].

        # scalar_stats contains some analysis of the activations and gradients,
        self.scalar_stats = None

        # the keys into self.stats[dim] are strings, which can be
        # "abs", "max", "min", "value", "positive", "rms".

@@ -288,6 +295,168 @@ class TensorDiagnostic(object):
        )


class ScalarDiagnostic(object):
    """This class is not directly used by the user, it is responsible for
    collecting diagnostics for a single module (subclass of torch.nn.Module) that
    represents some kind of nonlinearity, e.g. ReLU, sigmoid, etc.
    """

    def __init__(self, opts: TensorDiagnosticOptions, name: str):
        self.opts = opts
        self.name = name
        self.class_name = None  # will assign in accumulate_input()
        self.is_forward_pass = True

        self.tick_scale = None

        self.saved_inputs = []
        self.is_ok = True

        self.counts = None
        self.sum_grad = None
        self.sum_gradsq = None
        self.sum_abs_grad = None
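        # Illustrative note (added, not part of the original diff): once the
        # first input/grad pair is processed, counts, sum_grad, sum_gradsq and
        # sum_abs_grad become 1-D tensors of length 2 * num_ticks_per_side;
        # entry i counts, or sums gradients for, the activation values that
        # fell into tick i of the histogram built in process_input_and_grad().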

    def accumulate_input(self, x: Tensor, class_name: Optional[str] = None):
        """
        Called in forward pass.
        """
        if not self.is_forward_pass:
            # in case we did a forward pass without a backward pass, for some reason.
            self.saved_inputs = []
            self.is_forward_pass = True

        if class_name is not None:
            self.class_name = class_name
        if not self.is_ok:
            return

        limit = 10
        if len(self.saved_inputs) > limit:
            print(f"ERROR: forward pass called for this module over {limit} times with no backward pass. "
                  f"Will not accumulate scalar stats.")
            self.is_ok = False
            return
        self.saved_inputs.append(x)

    def accumulate_output_grad(self, grad: Tensor):
        if not self.is_ok:
            return
        if self.is_forward_pass:
            self.is_forward_pass = False

        last_shape = 'n/a' if len(self.saved_inputs) == 0 else self.saved_inputs[-1].shape
        if len(self.saved_inputs) == 0 or grad.shape != last_shape:
            print(f"ERROR: shape mismatch or no forward activation present when backward "
                  f"pass called: grad shape={tuple(grad.shape)}, num-saved-inputs={len(self.saved_inputs)}"
                  f", shape-of-last-saved-input={last_shape}")
            self.is_ok = False
            return

        x = self.saved_inputs.pop()
        self.process_input_and_grad(x, grad)

    def process_input_and_grad(self, x: Tensor, grad: Tensor):
        assert x.shape == grad.shape
        x = x.flatten()
        grad = grad.flatten()

        num_ticks_per_side = 256

        if self.tick_scale is None:
            x_abs_sorted = x.abs().sort()[0]
            # take the 98th percentile as the largest value we count separately.
            index = int(x.numel() * 0.98)
            self.tick_scale = float(x_abs_sorted[index] / num_ticks_per_side)

            # we will integerize x into ticks tick_scale * [-num_ticks_per_side .. num_ticks_per_side - 1].
            self.counts = torch.zeros(2 * num_ticks_per_side, dtype=torch.long, device=x.device)
            self.sum_grad = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device)
            # sum_gradsq is for getting error bars.
            self.sum_gradsq = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device)
            self.sum_abs_grad = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device)
            print("tick scale:", self.tick_scale)

        # this will truncate toward zero.
        x = (x / self.tick_scale).to(torch.long)
        x = x.clamp_(min=-num_ticks_per_side, max=num_ticks_per_side - 1)
        x = x + num_ticks_per_side
        print("x indexes: ", x)

        self.counts.index_add_(dim=0, index=x, source=torch.ones_like(x))
        self.sum_grad.index_add_(dim=0, index=x, source=grad.to(torch.double))
        self.sum_gradsq.index_add_(dim=0, index=x, source=(grad * grad).to(torch.double))
        self.sum_abs_grad.index_add_(dim=0, index=x, source=grad.abs().to(torch.double))
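
    # Illustrative sketch (added, not part of the original diff): the bucketing
    # above, on a toy tensor with num_ticks_per_side = 4 and tick_scale = 0.5:
    #
    #   x     = torch.tensor([-2.3, -0.7, 0.0, 0.6, 1.9])
    #   ticks = (x / 0.5).to(torch.long)            # [-4, -1, 0, 1, 3]
    #   ticks = ticks.clamp(min=-4, max=3) + 4      # [ 0,  3, 4, 5, 7]
    #   counts.index_add_(0, ticks, torch.ones_like(ticks))
    #
    # so counts histograms the activations per tick, while sum_grad, sum_gradsq
    # and sum_abs_grad accumulate the matching output-gradient statistics.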

    def print_diagnostics(self):
        """Print diagnostics."""
        if self.is_ok is False or self.counts is None:
            print(f"Warning: no stats accumulated for {self.name}, is_ok={self.is_ok}")
            return

        counts = self.counts.to('cpu')
        sum_grad = self.sum_grad.to(device='cpu', dtype=torch.float32)
        sum_gradsq = self.sum_gradsq.to(device='cpu', dtype=torch.float32)
        sum_abs_grad = self.sum_abs_grad.to(device='cpu', dtype=torch.float32)

        counts_cumsum = counts.cumsum(dim=0)  # use the CPU copy so everything below is on one device.
        counts_tot = counts_cumsum[-1]

        # subdivide the distribution into `num_bins` intervals for analysis, for greater
        # statistical significance.  Each bin corresponds to a multiple of the original 'tick' intervals.
        num_bins = 20

        # integer division
        counts_per_bin = (counts_tot // num_bins) + 1
        bin_indexes = counts_cumsum // counts_per_bin
        bin_indexes = bin_indexes.clamp(min=0, max=num_bins).to(torch.long)

        bin_counts = torch.zeros(num_bins, dtype=torch.long)
        bin_counts.index_add_(dim=0, index=bin_indexes, source=counts)
        bin_grad = torch.zeros(num_bins)
        bin_grad.index_add_(dim=0, index=bin_indexes, source=sum_grad)
        bin_gradsq = torch.zeros(num_bins)
        bin_gradsq.index_add_(dim=0, index=bin_indexes, source=sum_gradsq)
        bin_abs_grad = torch.zeros(num_bins)
        bin_abs_grad.index_add_(dim=0, index=bin_indexes, source=sum_abs_grad)

        avg_grad = (bin_grad / bin_counts)
        avg_grad_stddev = (bin_gradsq / bin_counts).sqrt()

        bin_boundary_counts = torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin
        bin_tick_indexes = torch.searchsorted(counts_cumsum, bin_boundary_counts)
        # boundaries are the "x" values between the bins, e.g. corresponding to the
        # locations of percentiles of the distribution.
        num_ticks_per_side = counts.numel() // 2
        bin_boundaries = (bin_tick_indexes - num_ticks_per_side) * self.tick_scale

        bin_grad = bin_grad / (bin_counts + 1)
        bin_conf_interval = bin_gradsq.sqrt() / (bin_counts + 1)  # consider this a standard deviation.
        # bin_grad / bin_abs_grad will give us a sense of how important,
        # in a practical sense, the gradients are.
        bin_abs_grad = bin_abs_grad / (bin_counts + 1)

        bin_rel_grad = bin_grad / (bin_abs_grad + 1.0e-20)
        bin_conf = bin_grad / (bin_conf_interval + 1.0e-20)

        def tensor_to_str(x: Tensor):
            x = ["%.2g" % f for f in x]
            x = "[" + " ".join(x) + "]"
            return x

        maybe_class_name = f" type={self.class_name}," if self.class_name is not None else ""

        print(
            f"module={self.name},{maybe_class_name} bin-boundaries={tensor_to_str(bin_boundaries)}, "
            f"rel_grad={tensor_to_str(bin_rel_grad)}, grad_conf={tensor_to_str(bin_conf)}"
        )
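
    # Worked toy example (added for illustration, not in the original diff):
    # with counts = [5, 0, 3, 2] and num_bins = 2:
    #   counts_cumsum  = [5, 5, 8, 10]
    #   counts_per_bin = 10 // 2 + 1 = 6
    #   bin_indexes    = [0, 0, 1, 1]
    # i.e. each original tick is assigned to a bin holding roughly
    # counts_tot / num_bins values, so the bins are approximately
    # equal-occupancy (percentile) bins of the activation distribution.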


class ModelDiagnostic(object):
    """This class stores diagnostics for all tensors in the torch.nn.Module.

@@ -306,9 +475,11 @@ class ModelDiagnostic(object):
        self.opts = opts
        self.diagnostics = dict()

    def __getitem__(self, name: str):
        T = ScalarDiagnostic if name[-7:] == '.scalar' else TensorDiagnostic
        if name not in self.diagnostics:
            self.diagnostics[name] = TensorDiagnostic(self.opts, name)
            self.diagnostics[name] = T(self.opts, name)
        return self.diagnostics[name]
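
    # Illustration (added, not part of the original diff): keys ending in
    # ".scalar" are routed to the new ScalarDiagnostic, any other key to
    # TensorDiagnostic, e.g. (with hypothetical module names):
    #   diag["encoder.relu.scalar"]   -> ScalarDiagnostic
    #   diag["encoder.relu.grad[0]"]  -> TensorDiagnostic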

    def print_diagnostics(self):

@@ -343,7 +514,7 @@ def attach_diagnostics(

        # Setting _model_diagnostic=ans and _name=name below, instead of trying to
        # capture the variables, ensures that we use the current values.
        # (matters for name, since the variable gets overwritten).
        # (this matters for `name`, since the variable gets overwritten).
        # These closures don't really capture by value, only by
        # "the final value the variable got in the function" :-(
        def forward_hook(

@@ -373,8 +544,32 @@ def attach_diagnostics(
                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
                                                                       class_name=type(_module).__name__)

        module.register_forward_hook(forward_hook)
        module.register_backward_hook(backward_hook)
        if type(module).__name__ in ["Sigmoid", "Tanh", "ReLU", "TanSwish", "Swish", "DoubleSwish"]:
            # For these specific module types, accumulate some additional diagnostics
            # that can help us improve the activation function.  These require a lot of
            # memory to save the forward activations, so we limit this to a few select classes.
            # Note: this will not work correctly for all model types.
            def scalar_forward_hook(
                _module, _input, _output, _model_diagnostic=ans, _name=name
            ):
                if isinstance(_input, tuple):
                    _input, = _input
                assert isinstance(_input, Tensor)
                _model_diagnostic[f"{_name}.scalar"].accumulate_input(_input,
                                                                      class_name=type(_module).__name__)

            def scalar_backward_hook(
                _module, _input, _output, _model_diagnostic=ans, _name=name
            ):
                if isinstance(_output, tuple):
                    _output, = _output
                assert isinstance(_output, Tensor)
                _model_diagnostic[f"{_name}.scalar"].accumulate_output_grad(_output)

            module.register_forward_hook(scalar_forward_hook)
            module.register_backward_hook(scalar_backward_hook)


    for name, parameter in model.named_parameters():

@@ -399,7 +594,7 @@ def _test_tensor_diagnostic():

    diagnostic.print_diagnostics()

    model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80))
    model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 80))

    diagnostic = attach_diagnostics(model, opts)
    for _ in range(10):
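    # Hypothetical usage sketch (added for illustration; the body of the test
    # loop is not part of this diff).  A minimal driver exercising both the
    # tensor hooks and the new scalar hooks might look like:
    #
    #   for _ in range(10):
    #       x = torch.randn(20, 100)
    #       y = model(x)
    #       y.sum().backward()
    #   diagnostic.print_diagnostics()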