From b7cad258bb7a52abad022b4a378e800a7fdda3ab Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 30 Nov 2022 15:57:24 +0800
Subject: [PATCH] Draft of new diagnostics for activations

---
 icefall/diagnostics.py | 209 +++++++++++++++++++++++++++++++++++++++--
 1 file changed, 202 insertions(+), 7 deletions(-)

diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index a380cdb52..c808f15d4 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -105,7 +105,7 @@ class TensorAndCount:
 
 class TensorDiagnostic(object):
     """This class is not directly used by the user, it is responsible for
-    collecting diagnostics for a single parameter tensor of a torch.nn.Module.
+    collecting diagnostics for a module or parameter tensor of a torch.nn.Module.
 
     Args:
      opts:
@@ -120,7 +120,14 @@ class TensorDiagnostic(object):
         self.name = name
         self.class_name = None  # will assign in accumulate()
 
-        self.stats = None  # we'll later assign a list to this data member. It's a list of dict.
+        self.stats = None  # we'll later assign a list to self.stats.
+        # It's a list of dicts, indexed by dim (i.e. by the
+        # axis of the tensor).  The dicts, in turn, are indexed
+        # by `stats-type`, which is one of the strings in
+        # ["abs", "max", "min", "positive", "value", "rms"].
+
+        # scalar_stats contains some analysis of the activations and gradients.
+        self.scalar_stats = None
 
         # the keys into self.stats[dim] are strings, whose values can be
         # "abs", "max", "min" ,"value", "positive", "rms", "value".
@@ -288,6 +295,168 @@ class TensorDiagnostic(object):
         )
 
 
+class ScalarDiagnostic(object):
+    """This class is not directly used by the user, it is responsible for
+    collecting diagnostics for a single module (a subclass of torch.nn.Module)
+    that represents some kind of nonlinearity, e.g. ReLU, sigmoid, etc.
+    """
+
+    def __init__(self, opts: TensorDiagnosticOptions, name: str):
+        self.opts = opts
+        self.name = name
+        self.class_name = None  # will assign in accumulate()
+        self.is_forward_pass = True
+
+        self.tick_scale = None
+
+        self.saved_inputs = []
+        self.is_ok = True
+
+        self.counts = None
+        self.sum_grad = None
+        self.sum_gradsq = None
+        self.sum_abs_grad = None
+
+    def accumulate_input(self, x: Tensor, class_name: Optional[str] = None):
+        """
+        Called in the forward pass: saves the input so that it can be matched
+        up with the output gradient in the backward pass.
+        """
+        if not self.is_forward_pass:
+            # in case we did a forward pass without a backward pass, for some reason.
+            self.saved_inputs = []
+            self.is_forward_pass = True
+
+        if class_name is not None:
+            self.class_name = class_name
+        if not self.is_ok:
+            return
+
+        limit = 10
+        if len(self.saved_inputs) > limit:
+            print(f"ERROR: forward pass called for this module over {limit} times "
+                  f"with no backward pass.  Will not accumulate scalar stats.")
+            self.is_ok = False
+            return
+        self.saved_inputs.append(x)
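+
+    # Matching each saved input with its output gradient relies on autograd
+    # running backward hooks in the reverse of the order of the forward calls,
+    # so saved_inputs behaves like a stack: e.g. if this module ran 3 times in
+    # the forward pass, the backward hooks fire for calls 3, 2, 1, and
+    # accumulate_output_grad() below pops the matching saved input each time.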
" + f" Will not accumulate scalar stats.") + self.is_ok = False + return + self.saved_inputs.append(x) + + def accumulate_output_grad(self, grad: Tensor): + if not self.is_ok: + return + if self.is_forward_pass: + self.is_forward_pass = False + + last_shape = 'n/a' if len(self.saved_inputs) == 0 else self.saved_inputs[-1].shape + if len(self.saved_inputs) == 0 or grad.shape != last_shape: + print(f"ERROR: shape mismatch or no forward activation present when backward " + f"pass called: grad shape ={tuple(grad.shape)}, num-saved-inputs={len(self.saved_inputs)}" + f", shape-of-last-saved-input={last_shape}") + self.is_ok = False + return + + x = self.saved_inputs.pop() + self.process_input_and_grad(x, grad) + + def process_input_and_grad(self, x: Tensor, grad: Tensor): + assert x.shape == grad.shape + x = x.flatten() + grad = grad.flatten() + + num_ticks_per_side = 256 + + if self.tick_scale is None: + x_abs_sorted = x.abs().sort()[0] + # take the 98th percentile as the largest value we count separately. + index = int(x.numel() * 0.98) + self.tick_scale = float(x_abs_sorted[index] / num_ticks_per_side) + + # integerize from tick * (-num ticks_per_side .. num_ticks_per_side - 1] + self.counts = torch.zeros(2 * num_ticks_per_side, dtype=torch.long, device=x.device) + self.sum_grad = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device) + # sum_gradsq is for getting error bars. + self.sum_gradsq = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device) + self.sum_abs_grad = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device) + print("tick scale:", self.tick_scale) + + # this will round down. + x = (x / self.tick_scale).to(torch.long) + x = x.clamp_(min=-num_ticks_per_side, max=num_ticks_per_side - 1) + x = x + num_ticks_per_side + print("x indexes: ", x) + + self.counts.index_add_(dim=0, index=x, source=torch.ones_like(x)) + self.sum_grad.index_add_(dim=0, index=x, source=grad.to(torch.double)) + self.sum_gradsq.index_add_(dim=0, index=x, source=(grad*grad).to(torch.double)) + self.sum_abs_grad.index_add_(dim=0, index=x, source=grad.abs().to(torch.double)) + + + def print_diagnostics(self): + """Print diagnostics.""" + if self.is_ok is False or self.counts is None: + print(f"Warning: no stats accumulated for {self.name}, is_ok={self.is_ok}") + return + + counts = self.counts.to('cpu') + sum_grad = self.sum_grad.to(device='cpu', dtype=torch.float32) + sum_gradsq = self.sum_gradsq.to(device='cpu', dtype=torch.float32) + sum_abs_grad = self.sum_abs_grad.to(device='cpu', dtype=torch.float32) + + counts_cumsum = self.counts.cumsum(dim=0) + counts_tot = counts_cumsum[-1] + + # subdivide the distribution up into `num_bins` intervals for analysis, for greater + # statistical significance. each bin corresponds to multiple of the original 'tick' intervals. 
+
+    def print_diagnostics(self):
+        """Print diagnostics."""
+        if not self.is_ok or self.counts is None:
+            print(f"Warning: no stats accumulated for {self.name}, is_ok={self.is_ok}")
+            return
+
+        counts = self.counts.to('cpu')
+        sum_grad = self.sum_grad.to(device='cpu', dtype=torch.float32)
+        sum_gradsq = self.sum_gradsq.to(device='cpu', dtype=torch.float32)
+        sum_abs_grad = self.sum_abs_grad.to(device='cpu', dtype=torch.float32)
+
+        counts_cumsum = counts.cumsum(dim=0)
+        counts_tot = counts_cumsum[-1]
+
+        # subdivide the distribution into `num_bins` intervals for analysis,
+        # for greater statistical significance.  Each bin corresponds to
+        # multiple of the original 'tick' intervals.
+        num_bins = 20
+
+        # integer division; the +1 ensures the largest bin index stays below num_bins.
+        counts_per_bin = (counts_tot // num_bins) + 1
+        bin_indexes = counts_cumsum // counts_per_bin
+        bin_indexes = bin_indexes.clamp(min=0, max=num_bins - 1).to(torch.long)
+
+        bin_counts = torch.zeros(num_bins, dtype=torch.long)
+        bin_counts.index_add_(dim=0, index=bin_indexes, source=counts)
+        bin_grad = torch.zeros(num_bins)
+        bin_grad.index_add_(dim=0, index=bin_indexes, source=sum_grad)
+        bin_gradsq = torch.zeros(num_bins)
+        bin_gradsq.index_add_(dim=0, index=bin_indexes, source=sum_gradsq)
+        bin_abs_grad = torch.zeros(num_bins)
+        bin_abs_grad.index_add_(dim=0, index=bin_indexes, source=sum_abs_grad)
+
+        bin_boundary_counts = torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin
+        bin_tick_indexes = torch.searchsorted(counts_cumsum, bin_boundary_counts)
+        # boundaries are the "x" values between the bins, corresponding roughly
+        # to the locations of percentiles of the distribution.
+        num_ticks_per_side = counts.numel() // 2
+        bin_boundaries = (bin_tick_indexes - num_ticks_per_side) * self.tick_scale
+
+        bin_grad = bin_grad / (bin_counts + 1)
+        # consider this roughly the standard error of bin_grad.
+        bin_conf_interval = bin_gradsq.sqrt() / (bin_counts + 1)
+        # bin_grad / bin_abs_grad will give us a sense of how consistent, and
+        # hence how practically important, the gradients are.
+        bin_abs_grad = bin_abs_grad / (bin_counts + 1)
+
+        bin_rel_grad = bin_grad / (bin_abs_grad + 1.0e-20)
+        bin_conf = bin_grad / (bin_conf_interval + 1.0e-20)
+
+        def tensor_to_str(x: Tensor):
+            x = ["%.2g" % f for f in x]
+            x = "[" + " ".join(x) + "]"
+            return x
+
+        maybe_class_name = f" type={self.class_name}," if self.class_name is not None else ""
+
+        print(
+            f"module={self.name},{maybe_class_name} bin-boundaries={tensor_to_str(bin_boundaries)}, "
+            f"rel_grad={tensor_to_str(bin_rel_grad)}, grad_conf={tensor_to_str(bin_conf)}"
+        )
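+
+# To read the output of ScalarDiagnostic.print_diagnostics(): bin-boundaries
+# are the activation values delimiting the equal-count bins, i.e. roughly the
+# 0th, 5th, ..., 100th percentiles of the input when num_bins == 20; rel_grad
+# is the mean output-grad in each bin divided by the mean absolute output-grad,
+# so values near +1 or -1 mean the gradient has a consistent sign for inputs
+# in that range; grad_conf divides the mean grad by (roughly) its standard
+# error, so |grad_conf| much greater than 1 suggests the sign is not just noise.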
 
 
 class ModelDiagnostic(object):
     """This class stores diagnostics for all tensors in the torch.nn.Module.
 
@@ -306,9 +475,11 @@ class ModelDiagnostic(object):
         self.opts = opts
         self.diagnostics = dict()
 
     def __getitem__(self, name: str):
+        T = ScalarDiagnostic if name.endswith('.scalar') else TensorDiagnostic
         if name not in self.diagnostics:
-            self.diagnostics[name] = TensorDiagnostic(self.opts, name)
+            self.diagnostics[name] = T(self.opts, name)
         return self.diagnostics[name]
 
     def print_diagnostics(self):
@@ -343,7 +514,7 @@ def attach_diagnostics(
 
         # Setting model_diagnostic=ans and n=name below, instead of trying to
         # capture the variables, ensures that we use the current values.
-        # (matters for name, since the variable gets overwritten).
+        # (this matters for `name`, since the variable gets overwritten).
         # These closures don't really capture by value, only by
         # "the final value the variable got in the function" :-(
         def forward_hook(
             _module, _input, _output, _model_diagnostic=ans, _name=name
@@ -373,8 +544,32 @@ def attach_diagnostics(
                 _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o, class_name=type(_module).__name__)
 
-        module.register_forward_hook(forward_hook)
-        module.register_backward_hook(backward_hook)
+        module.register_forward_hook(forward_hook)
+        module.register_backward_hook(backward_hook)
+
+        if type(module).__name__ in ["Sigmoid", "Tanh", "ReLU", "TanSwish", "Swish", "DoubleSwish"]:
+            # For these specific module types, accumulate some additional
+            # diagnostics that can help us improve the activation function.
+            # These require a lot of memory, to save the forward activations,
+            # so we limit this to a few select classes.
+            # Note: this will not work correctly for all model types.
+            def scalar_forward_hook(
+                _module, _input, _output, _model_diagnostic=ans, _name=name
+            ):
+                if isinstance(_input, tuple):
+                    _input, = _input
+                assert isinstance(_input, Tensor)
+                _model_diagnostic[f"{_name}.scalar"].accumulate_input(
+                    _input, class_name=type(_module).__name__
+                )
+
+            def scalar_backward_hook(
+                _module, _input, _output, _model_diagnostic=ans, _name=name
+            ):
+                if isinstance(_output, tuple):
+                    _output, = _output
+                assert isinstance(_output, Tensor)
+                _model_diagnostic[f"{_name}.scalar"].accumulate_output_grad(_output)
+
+            module.register_forward_hook(scalar_forward_hook)
+            module.register_backward_hook(scalar_backward_hook)
 
     for name, parameter in model.named_parameters():
@@ -399,7 +594,7 @@ def _test_tensor_diagnostic():
 
     diagnostic.print_diagnostics()
 
-    model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80))
+    model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 80))
 
     diagnostic = attach_diagnostics(model, opts)
     for _ in range(10):