From 87b4619f12598af840d9159436ad279b22041939 Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Thu, 3 Mar 2022 15:35:12 +0800
Subject: [PATCH] Update docs of arguments, and remove stats_types() function
 in TensorDiagnosticOptions object.

---
 .../ASR/transducer_stateless/train.py |   6 +-
 icefall/diagnostics.py                | 258 ++++++++++--------
 2 files changed, 149 insertions(+), 115 deletions(-)

diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index 9049be031..2cc6480d5 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -521,7 +521,6 @@ def train_one_epoch(
 
         if params.print_diagnostics and batch_idx == 5:
             return
-
         if batch_idx % params.log_interval == 0:
             logging.info(
                 f"Epoch {params.cur_epoch}, "
@@ -631,10 +630,11 @@ def run(rank, world_size, args):
     librispeech = LibriSpeechAsrDataModule(args)
 
     if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(2**22)  # allow 4 megabytes per sub-module
+        opts = diagnostics.TensorDiagnosticOptions(
+            2 ** 22
+        )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
-
     train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
         train_cuts += librispeech.train_clean_360_cuts()
diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index 2dff91805..847328d0f 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -1,94 +1,97 @@
-import torch
-from torch import Tensor
-from torch import nn
-import math
 import random
-from typing import Tuple, List
+from typing import List, Tuple
+
+import torch
+from torch import Tensor, nn
 
 
 class TensorDiagnosticOptions(object):
     """
     Options object for tensor diagnostics:
 
-    Args:
-      memory_limit: the maximum number of bytes per tensor (limits how many copies
-           of the tensor we cache).
-
+    Args:
+      memory_limit: the maximum number of bytes per tensor (limits how many
+        copies of the tensor we cache).
     """
-    def __init__(self, memory_limit: int,
-                 print_pos_ratio: bool = True):
+
+    def __init__(self, memory_limit: int):
         self.memory_limit = memory_limit
-        self.print_pos_ratio = print_pos_ratio
 
     def dim_is_summarized(self, size: int):
         return size > 10 and size != 31
 
-    def stats_types(self):
-        if self.print_pos_ratio:
-            return ["mean-abs", "pos-ratio"]
-        else:
-            return ["mean-abs"]
-
-
-def get_sum_abs_stats(x: Tensor, dim: int,
-                      stats_type: str) -> Tuple[Tensor, int]:
+
+def get_sum_abs_stats(
+    x: Tensor, dim: int, stats_type: str
+) -> Tuple[Tensor, int]:
     """
-    Returns the sum-of-absolute-value of this Tensor, for each
-    index into the specified axis/dim of the tensor.
+    Returns the sum-of-absolute-value of this Tensor, for each index into
+    the specified axis/dim of the tensor.
+
     Args:
-     x: Tensor, tensor to be analyzed
-     dim: dimension with 0 <= dim < x.ndim
-     stats_type: either "mean-abs" in which case the stats represent the
-         mean absolute value, or "pos-ratio" in which case the
-         stats represent the proportion of positive values (actually:
-         the tensor is count of positive values, count is the count of
-         all values).
-    Returns (sum_abs, count)
-     where sum_abs is a Tensor of shape (x.shape[dim],), and the count
-     is an integer saying how many items were counted in each element
-     of sum_abs.
+      x: Tensor, tensor to be analyzed
+      dim: dimension with 0 <= dim < x.ndim
+      stats_type: either "mean-abs", in which case the stats represent the
+        mean absolute value, or "pos-ratio", in which case the stats
+        represent the proportion of positive values (more precisely: the
+        tensor holds counts of positive values, and count counts all values).
+
+    Returns (sum_abs, count):
+      where sum_abs is a Tensor of shape (x.shape[dim],), and count
+      is an integer saying how many items were counted in each element
+      of sum_abs.
     """
     if stats_type == "mean-abs":
         x = x.abs()
     else:
         assert stats_type == "pos-ratio"
         x = (x > 0).to(dtype=torch.float)
+
     orig_numel = x.numel()
-    sum_dims = [ d for d in range(x.ndim) if d != dim ]
+    sum_dims = [d for d in range(x.ndim) if d != dim]
     x = torch.sum(x, dim=sum_dims)
     count = orig_numel // x.numel()
     x = x.flatten()
+
     return x, count
 
 
-def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
-                            options: TensorDiagnosticOptions,
-                            sizes_same: bool,
-                            stats_type: str):
+def get_diagnostics_for_dim(
+    dim: int,
+    tensors: List[Tensor],
+    options: TensorDiagnosticOptions,
+    sizes_same: bool,
+    stats_type: str,
+) -> str:
     """
     This function gets diagnostics for a dimension of a module.
+
     Args:
-     dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
-     options: options object
-     sizes_same: true if all the tensor sizes are the same on this dimension
-     stats_type: either "mean-abs" or "pos-ratio", dictates the type of stats
-          we accumulate, mean-abs is mean absolute value, "pos-ratio"
-          is proportion of positive to nonnegative values.
+      dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
+      tensors: list of cached tensors to get the stats
+      options: options object
+      sizes_same: true if all the tensor sizes are the same on this dimension
+      stats_type: either "mean-abs" or "pos-ratio", dictates the type of
+        stats we accumulate: "mean-abs" is the mean absolute value, and
+        "pos-ratio" is the proportion of positive values.
+
     Returns:
-     Diagnostic as a string, either percentiles or the actual values,
-     see the code.
+      Diagnostic as a string, either percentiles or the actual values,
+      see the code.
     """
+
     # stats_and_counts is a list of pairs (Tensor, int)
-    stats_and_counts = [ get_sum_abs_stats(x, dim, stats_type) for x in tensors ]
-    stats = [ x[0] for x in stats_and_counts ]
-    counts = [ x[1] for x in stats_and_counts ]
+    stats_and_counts = [get_sum_abs_stats(x, dim, stats_type) for x in tensors]
+    stats = [x[0] for x in stats_and_counts]
+    counts = [x[1] for x in stats_and_counts]
     if sizes_same:
         stats = torch.stack(stats).sum(dim=0)
         count = sum(counts)
         stats = stats / count
     else:
-        stats = [ x[0] / x[1] for x in stats_and_counts ]
+        stats = [x[0] / x[1] for x in stats_and_counts]
        stats = torch.cat(stats, dim=0)
+
     # if `summarize` we print percentiles of the stats; else,
     # we print out individual elements.
     summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
@@ -101,89 +104,117 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
         for i in range(num_percentiles + 1):
             index = (i * (size - 1)) // num_percentiles
             percentiles.append(stats[index].item())
-        percentiles = [ '%.2g' % x for x in percentiles ]
-        percentiles = ' '.join(percentiles)
-        return f'percentiles: [{percentiles}]'
+        percentiles = ["%.2g" % x for x in percentiles]
+        percentiles = " ".join(percentiles)
+        return f"percentiles: [{percentiles}]"
     else:
         stats = stats.tolist()
-        stats = [ '%.2g' % x for x in stats ]
-        stats = '[' + ' '.join(stats) + ']'
+        stats = ["%.2g" % x for x in stats]
+        stats = "[" + " ".join(stats) + "]"
         return stats
 
 
+def print_diagnostics_for_dim(
+    name: str, dim: int, tensors: List[Tensor], options: TensorDiagnosticOptions
+):
+    """
+    This function prints diagnostics for a dimension of a tensor.
 
-def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor],
-                              options: TensorDiagnosticOptions):
+    Args:
+      name: the tensor name
+      dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
+      tensors: list of cached tensors to get the stats
+      options: options object
+    """
 
-    for stats_type in options.stats_types():
+    for stats_type in ["mean-abs", "pos-ratio"]:
         # stats_type will be "mean-abs" or "pos-ratio".
-        sizes = [ x.shape[dim] for x in tensors ]
-        sizes_same = all([ x == sizes[0] for x in sizes ])
-        s = get_diagnostics_for_dim(dim, tensors,
-                                    options, sizes_same,
-                                    stats_type)
+        sizes = [x.shape[dim] for x in tensors]
+        sizes_same = all([x == sizes[0] for x in sizes])
+        s = get_diagnostics_for_dim(
+            dim, tensors, options, sizes_same, stats_type
+        )
 
         min_size = min(sizes)
         max_size = max(sizes)
         size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}"
-        # stats_type will be "mean-abs" or "pos-ratio".
         print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}")
 
 
 class TensorDiagnostic(object):
     """
-    This class is not directly used by the user, it is responsible for collecting
-    diagnostics for a single parameter tensor of a torch.Module.
+    This class is not directly used by the user; it is responsible for
+    collecting diagnostics for a single parameter tensor of a torch.nn.Module.
+
+    Attributes:
+      opts: options object.
+      name: tensor name.
+      saved_tensors: list of cached tensors.
     """
-    def __init__(self,
-                 opts: TensorDiagnosticOptions,
-                 name: str):
+
+    def __init__(self, opts: TensorDiagnosticOptions, name: str):
         self.name = name
         self.opts = opts
         self.saved_tensors = []
 
     def accumulate(self, x):
+        """Accumulate tensors."""
        if isinstance(x, Tuple):
             x = x[0]
         if not isinstance(x, Tensor):
             return
-        if x.device == torch.device('cpu'):
+        if x.device == torch.device("cpu"):
             x = x.detach().clone()
         else:
-            x = x.detach().to('cpu', non_blocking=True)
+            x = x.detach().to("cpu", non_blocking=True)
         self.saved_tensors.append(x)
-        l = len(self.saved_tensors)
-        if l & (l - 1) == 0:  # power of 2..
+        num = len(self.saved_tensors)
+        if num & (num - 1) == 0:  # power of 2.
             self._limit_memory()
 
     def _limit_memory(self):
+        """Keep only the most recently cached tensors, to limit memory."""
         if len(self.saved_tensors) > 1024:
             self.saved_tensors = self.saved_tensors[-1024:]
             return
 
         tot_mem = 0.0
         for i in reversed(range(len(self.saved_tensors))):
-            tot_mem += self.saved_tensors[i].numel() * self.saved_tensors[i].element_size()
+            tot_mem += (
+                self.saved_tensors[i].numel()
+                * self.saved_tensors[i].element_size()
+            )
             if tot_mem > self.opts.memory_limit:
                 self.saved_tensors = self.saved_tensors[i:]
                 return
 
     def print_diagnostics(self):
+        """Print diagnostics for each dimension of the tensor."""
         if len(self.saved_tensors) == 0:
             print("{name}: no stats".format(name=self.name))
             return
+
         if self.saved_tensors[0].ndim == 0:
             # ensure there is at least one dim.
-            self.saved_tensors = [ x.unsqueeze(0) for x in self.saved_tensors ]
+            self.saved_tensors = [x.unsqueeze(0) for x in self.saved_tensors]
 
         ndim = self.saved_tensors[0].ndim
         for dim in range(ndim):
-            print_diagnostics_for_dim(self.name, dim,
-                                      self.saved_tensors,
-                                      self.opts)
+            print_diagnostics_for_dim(
+                self.name, dim, self.saved_tensors, self.opts
+            )
 
 
 class ModelDiagnostic(object):
+    """
+    This class stores diagnostics for all tensors of a torch.nn.Module.
+
+    Attributes:
+      diagnostics: a dictionary, whose keys are tensor names and whose
+        values are the corresponding TensorDiagnostic objects.
+      opts: options object.
+    """
+
     def __init__(self, opts: TensorDiagnosticOptions):
         self.diagnostics = dict()
         self.opts = opts
 
@@ -194,35 +225,51 @@ class ModelDiagnostic(object):
         return self.diagnostics[name]
 
     def print_diagnostics(self):
+        """Print diagnostics for each tensor."""
         for k in sorted(self.diagnostics.keys()):
             self.diagnostics[k].print_diagnostics()
 
 
+def attach_diagnostics(
+    model: nn.Module, opts: TensorDiagnosticOptions
+) -> ModelDiagnostic:
+    """
+    Attach a ModelDiagnostic object to the model by
+    1) registering a forward hook and a backward hook on each module, to
+    accumulate its output tensors and gradient tensors, respectively;
+    2) registering a backward hook on each module parameter, to accumulate
+    its values and gradients.
+
+    Args:
+      model: the model to be analyzed.
+      opts: options object.
+
+    Returns:
+      The ModelDiagnostic object attached to the model.
+    """
 
-def attach_diagnostics(model: nn.Module,
-                       opts: TensorDiagnosticOptions) -> ModelDiagnostic:
     ans = ModelDiagnostic(opts)
     for name, module in model.named_modules():
-        if name == '':
+        if name == "":
             name = "<top-level>"
 
-        forward_diagnostic = TensorDiagnostic(opts, name + ".output")
-        backward_diagnostic = TensorDiagnostic(opts, name + ".grad")
-
-        # setting model_diagnostic=ans and n=name below, instead of trying to capture the variables,
-        # ensures that we use the current values. (matters for name, since
-        # the variable gets overwritten). these closures don't really capture
-        # by value, only by "the final value the variable got in the function" :-(
-        def forward_hook(_module, _input, _output,
-                         _model_diagnostic=ans, _name=name):
+        # setting _model_diagnostic=ans and _name=name below, instead of
+        # trying to capture the variables, ensures that we use the current
+        # values. (matters for name, since the variable gets overwritten).
+ # these closures don't really capture by value, only by + # "the final value the variable got in the function" :-( + def forward_hook( + _module, _input, _output, _model_diagnostic=ans, _name=name + ): if isinstance(_output, Tensor): _model_diagnostic[f"{_name}.output"].accumulate(_output) elif isinstance(_output, tuple): for i, o in enumerate(_output): _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o) - def backward_hook(_module, _input, _output, - _model_diagnostic=ans, _name=name): + def backward_hook( + _module, _input, _output, _model_diagnostic=ans, _name=name + ): if isinstance(_output, Tensor): _model_diagnostic[f"{_name}.grad"].accumulate(_output) elif isinstance(_output, tuple): @@ -234,20 +281,19 @@ def attach_diagnostics(model: nn.Module, for name, parameter in model.named_parameters(): - def param_backward_hook(grad, - _parameter=parameter, - _model_diagnostic=ans, - _name=name): + def param_backward_hook( + grad, _parameter=parameter, _model_diagnostic=ans, _name=name + ): _model_diagnostic[f"{_name}.param_value"].accumulate(_parameter) _model_diagnostic[f"{_name}.param_grad"].accumulate(grad) parameter.register_hook(param_backward_hook) + return ans - def _test_tensor_diagnostic(): - opts = TensorDiagnosticOptions(2**20, True) + opts = TensorDiagnosticOptions(2 ** 20) diagnostic = TensorDiagnostic(opts, "foo") @@ -268,17 +314,5 @@ def _test_tensor_diagnostic(): diagnostic.print_diagnostics() - -if __name__ == '__main__': +if __name__ == "__main__": _test_tensor_diagnostic() - - -def _test_func(): - ans = [] - for i in range(10): - x = list() - x.append(i) - def func(): - return x - ans.append(func) - return ans
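
For reference, here is a small worked example of the get_sum_abs_stats
semantics documented above (not part of the patch; the tensor values are
invented for illustration, only get_sum_abs_stats comes from this module):

    import torch

    from icefall.diagnostics import get_sum_abs_stats

    # x has shape (2, 3); with dim=1 every other dim is summed away, so the
    # returned tensor has shape (3,) and count = x.numel() // 3 = 2.
    x = torch.tensor([[1.0, -1.0, 2.0], [0.0, 3.0, -2.0]])

    sum_abs, count = get_sum_abs_stats(x, dim=1, stats_type="mean-abs")
    # sum_abs is tensor([1., 4., 4.]) and count is 2, so the mean absolute
    # values per index along dim=1 are [0.5, 2.0, 2.0].

    pos, count = get_sum_abs_stats(x, dim=1, stats_type="pos-ratio")
    # pos is tensor([1., 1., 1.]) and count is 2: each index along dim=1 has
    # one positive value out of two (zeros count as non-positive).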
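
And a minimal end-to-end sketch of the patched API on a toy model (the
model and shapes are invented for illustration; TensorDiagnosticOptions,
attach_diagnostics and print_diagnostics are what this patch touches):

    import torch
    from torch import nn

    from icefall import diagnostics

    model = nn.Sequential(nn.Linear(20, 50), nn.ReLU(), nn.Linear(50, 10))

    # 2 ** 22 bytes = 4 megabytes per tensor, the same value train.py passes.
    opts = diagnostics.TensorDiagnosticOptions(2 ** 22)
    diagnostic = diagnostics.attach_diagnostics(model, opts)

    # The registered hooks accumulate each module's outputs and gradients,
    # and each parameter's values and gradients, over these iterations.
    for _ in range(5):
        loss = model(torch.randn(8, 20)).sum()
        loss.backward()
        model.zero_grad()

    # Prints per-tensor, per-dimension stats, one line per stats type, e.g.
    #   module=0.output, dim=1, size=50, mean-abs percentiles: [...]
    diagnostic.print_diagnostics()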