diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index 847328d0f..d3c13703f 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -6,12 +6,12 @@ from torch import Tensor, nn


 class TensorDiagnosticOptions(object):
-    """
-    Options object for tensor diagnostics:
+    """Options object for tensor diagnostics:

     Args:
-      memory_limit: the maximum number of bytes per tensor (limits how many
-        copies of the tensor we cache).
+      memory_limit:
+        The maximum number of bytes per tensor (limits how many copies
+        of the tensor we cache).
     """

     def __init__(self, memory_limit: int):
@@ -24,22 +24,24 @@ class TensorDiagnosticOptions(object):
 def get_sum_abs_stats(
     x: Tensor, dim: int, stats_type: str
 ) -> Tuple[Tensor, int]:
-    """
-    Returns the sum-of-absolute-value of this Tensor, for each index into
+    """Returns the sum-of-absolute-value of this Tensor, for each index into
     the specified axis/dim of the tensor.

     Args:
-      x: Tensor, tensor to be analyzed
-      dim: dimension with 0 <= dim < x.ndim
-      stats_type: either "mean-abs" in which case the stats represent the
-        mean absolute value, or "pos-ratio" in which case the stats represent
-        the proportion of positive values (actually: the tensor is count of
-        positive values, count is the count of all values).
+      x:
+        Tensor, tensor to be analyzed
+      dim:
+        Dimension with 0 <= dim < x.ndim
+      stats_type:
+        Either "mean-abs" in which case the stats represent the mean absolute
+        value, or "pos-ratio" in which case the stats represent the proportion
+        of positive values (actually: the tensor is count of positive values,
+        count is the count of all values).

-    Returns (sum_abs, count):
-      where sum_abs is a Tensor of shape (x.shape[dim],), and the count
-      is an integer saying how many items were counted in each element
-      of sum_abs.
+    Returns:
+      (sum_abs, count) where sum_abs is a Tensor of shape (x.shape[dim],),
+      and the count is an integer saying how many items were counted in
+      each element of sum_abs.
     """
     if stats_type == "mean-abs":
         x = x.abs()
@@ -63,21 +65,24 @@ def get_diagnostics_for_dim(
     sizes_same: bool,
     stats_type: str,
 ) -> str:
-    """
-    This function gets diagnostics for a dimension of a module.
+    """This function gets diagnostics for a dimension of a module.

     Args:
-      dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
-      tensors: list of cached tensors to get the stats
-      options: options object
-      sizes_same: true if all the tensor sizes are the same on this dimension
+      dim:
+        The dimension to analyze, with 0 <= dim < tensors[0].ndim
+      tensors:
+        List of cached tensors to get the stats
+      options:
+        Options object
+      sizes_same:
+        True if all the tensor sizes are the same on this dimension
       stats_type: either "mean-abs" or "pos-ratio", dictates the type of
         stats we accumulate, mean-abs is mean absolute value, "pos-ratio"
         is proportion of positive to nonnegative values.

     Returns:
-       Diagnostic as a string, either percentiles or the actual values,
-       see the code.
+      Diagnostic as a string, either percentiles or the actual values,
+      see the code.
     """

     # stats_and_counts is a list of pair (Tensor, int)
@@ -92,11 +97,11 @@ def get_diagnostics_for_dim(
         stats = [x[0] / x[1] for x in stats_and_counts]
         stats = torch.cat(stats, dim=0)

-    # if `summarize` we print percentiles of the stats; else,
-    # we print out individual elements.
+    # If `summarize` we print percentiles of the stats;
+    # else, we print out individual elements.
     summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
     if summarize:
-        # print out percentiles.
+        # Print out percentiles.
         stats = stats.sort()[0]
         num_percentiles = 10
         size = stats.numel()
@@ -117,14 +122,17 @@ def get_diagnostics_for_dim(
 def print_diagnostics_for_dim(
     name: str, dim: int, tensors: List[Tensor], options: TensorDiagnosticOptions
 ):
-    """
-    This function prints diagnostics for a dimension of a tensor.
+    """This function prints diagnostics for a dimension of a tensor.

     Args:
-      name: the tensor name
-      dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
-      tensors: list of cached tensors to get the stats
-      options: options object
+      name:
+        The tensor name.
+      dim:
+        The dimension to analyze, with 0 <= dim < tensors[0].ndim.
+      tensors:
+        List of cached tensors to get the stats.
+      options:
+        Options object.
     """

     for stats_type in ["mean-abs", "pos-ratio"]:
@@ -142,19 +150,20 @@ def print_diagnostics_for_dim(


 class TensorDiagnostic(object):
-    """
-    This class is not directly used by the user, it is responsible for
+    """This class is not directly used by the user, it is responsible for
     collecting diagnostics for a single parameter tensor of a torch.nn.Module.

-    Attributes:
-      opts: options object.
-      name: tensor name.
-      saved_tensors: list of cached tensors.
+    Args:
+      opts:
+        Options object.
+      name:
+        The tensor name.
     """

     def __init__(self, opts: TensorDiagnosticOptions, name: str):
         self.name = name
         self.opts = opts
+        # A list to cache the tensors.
         self.saved_tensors = []

     def accumulate(self, x):
@@ -195,7 +204,7 @@ class TensorDiagnostic(object):
             return

         if self.saved_tensors[0].ndim == 0:
-            # ensure there is at least one dim.
+            # Ensure there is at least one dim.
             self.saved_tensors = [x.unsqueeze(0) for x in self.saved_tensors]

         ndim = self.saved_tensors[0].ndim
@@ -206,16 +215,16 @@ class TensorDiagnostic(object):


 class ModelDiagnostic(object):
-    """
-    This class stores diagnostics for all tensors in the torch.nn.Module.
+    """This class stores diagnostics for all tensors in the torch.nn.Module.

-    Attributes:
-      diagnostics: a dictionary, whose keys are the tensors names and
-        the values are corresponding TensorDiagnostic objects.
-      opts: options object.
+    Args:
+      opts:
+        Options object.
     """

     def __init__(self, opts: TensorDiagnosticOptions):
+        # In this dictionary, the keys are tensors names and the values
+        # are corresponding TensorDiagnostic objects.
         self.diagnostics = dict()
         self.opts = opts

@@ -233,19 +242,20 @@ class ModelDiagnostic(object):
 def attach_diagnostics(
     model: nn.Module, opts: TensorDiagnosticOptions
 ) -> ModelDiagnostic:
-    """
-    Attach a ModelDiagnostic object to the model by
+    """Attach a ModelDiagnostic object to the model by
     1) registering forward hook and backward hook on each module, to accumulate
        its output tensors and gradient tensors, respectively;
     2) registering backward hook on each module parameter, to accumulate its
        values and gradients.

     Args:
-      model: the model to be analyzed.
-      opts: options object.
+      model:
+        the model to be analyzed.
+      opts:
+        Options object.

     Returns:
-       The ModelDiagnostic object attached to the model.
+      The ModelDiagnostic object attached to the model.
     """
     ans = ModelDiagnostic(opts)
@@ -253,10 +263,10 @@ def attach_diagnostics(
         if name == "":
             name = "<top-level>"

-        # setting model_diagnostic=ans and n=name below, instead of trying to
+        # Setting model_diagnostic=ans and n=name below, instead of trying to
         # capture the variables, ensures that we use the current values.
         # (matters for name, since the variable gets overwritten).
-        # these closures don't really capture by value, only by
+        # These closures don't really capture by value, only by
         # "the final value the variable got in the function" :-(
         def forward_hook(
             _module, _input, _output, _model_diagnostic=ans, _name=name
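For context, here is a minimal usage sketch (not part of the patch) of how the API documented above might be driven. It relies only on what is visible in these hunks, namely TensorDiagnosticOptions(memory_limit) and attach_diagnostics(model, opts) returning a ModelDiagnostic, plus one assumption: a print_diagnostics() reporting method on ModelDiagnostic, which does not appear in the hunks shown here. The toy model and the memory limit value are illustrative.

    import torch
    from torch import nn

    from icefall.diagnostics import TensorDiagnosticOptions, attach_diagnostics

    # Any nn.Module works; hooks are registered on each submodule and parameter.
    model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))

    # memory_limit caps the number of bytes cached per tensor (illustrative value).
    opts = TensorDiagnosticOptions(memory_limit=2 ** 22)
    diagnostic = attach_diagnostics(model, opts)

    # Run a few forward/backward passes so the hooks can accumulate
    # output tensors, gradients, and parameter values.
    for _ in range(3):
        loss = model(torch.randn(8, 10)).sum()
        loss.backward()

    # Assumed reporting entry point; it is not shown in the hunks above.
    diagnostic.print_diagnostics()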