Update docs of arguments.

This commit is contained in:
yaozengwei 2022-03-03 17:13:33 +08:00
parent 87b4619f12
commit 8be385f3bd

View File

@ -6,12 +6,12 @@ from torch import Tensor, nn
class TensorDiagnosticOptions(object): class TensorDiagnosticOptions(object):
""" """Options object for tensor diagnostics:
Options object for tensor diagnostics:
Args: Args:
memory_limit: the maximum number of bytes per tensor (limits how many memory_limit:
copies of the tensor we cache). The maximum number of bytes per tensor (limits how many copies
of the tensor we cache).
""" """
def __init__(self, memory_limit: int): def __init__(self, memory_limit: int):
@ -24,22 +24,24 @@ class TensorDiagnosticOptions(object):
def get_sum_abs_stats( def get_sum_abs_stats(
x: Tensor, dim: int, stats_type: str x: Tensor, dim: int, stats_type: str
) -> Tuple[Tensor, int]: ) -> Tuple[Tensor, int]:
""" """Returns the sum-of-absolute-value of this Tensor, for each index into
Returns the sum-of-absolute-value of this Tensor, for each index into
the specified axis/dim of the tensor. the specified axis/dim of the tensor.
Args: Args:
x: Tensor, tensor to be analyzed x:
dim: dimension with 0 <= dim < x.ndim Tensor, tensor to be analyzed
stats_type: either "mean-abs" in which case the stats represent the dim:
mean absolute value, or "pos-ratio" in which case the stats represent Dimension with 0 <= dim < x.ndim
the proportion of positive values (actually: the tensor is count of stats_type:
positive values, count is the count of all values). Either "mean-abs" in which case the stats represent the mean absolute
value, or "pos-ratio" in which case the stats represent the proportion
of positive values (actually: the tensor is count of positive values,
count is the count of all values).
Returns (sum_abs, count): Returns:
where sum_abs is a Tensor of shape (x.shape[dim],), and the count (sum_abs, count) where sum_abs is a Tensor of shape (x.shape[dim],),
is an integer saying how many items were counted in each element and the count is an integer saying how many items were counted in
of sum_abs. each element of sum_abs.
""" """
if stats_type == "mean-abs": if stats_type == "mean-abs":
x = x.abs() x = x.abs()
@ -63,21 +65,24 @@ def get_diagnostics_for_dim(
sizes_same: bool, sizes_same: bool,
stats_type: str, stats_type: str,
) -> str: ) -> str:
""" """This function gets diagnostics for a dimension of a module.
This function gets diagnostics for a dimension of a module.
Args: Args:
dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim dim:
tensors: list of cached tensors to get the stats The dimension to analyze, with 0 <= dim < tensors[0].ndim
options: options object tensors:
sizes_same: true if all the tensor sizes are the same on this dimension List of cached tensors to get the stats
options:
Options object
sizes_same:
True if all the tensor sizes are the same on this dimension
stats_type: either "mean-abs" or "pos-ratio", dictates the type of stats_type: either "mean-abs" or "pos-ratio", dictates the type of
stats we accumulate, mean-abs is mean absolute value, "pos-ratio" is stats we accumulate, mean-abs is mean absolute value, "pos-ratio" is
proportion of positive values among all values. proportion of positive values among all values.
Returns: Returns:
Diagnostic as a string, either percentiles or the actual values, Diagnostic as a string, either percentiles or the actual values,
see the code. see the code.
""" """
# stats_and_counts is a list of pairs (Tensor, int) # stats_and_counts is a list of pairs (Tensor, int)
@ -92,11 +97,11 @@ def get_diagnostics_for_dim(
stats = [x[0] / x[1] for x in stats_and_counts] stats = [x[0] / x[1] for x in stats_and_counts]
stats = torch.cat(stats, dim=0) stats = torch.cat(stats, dim=0)
# if `summarize` we print percentiles of the stats; else, # If `summarize` we print percentiles of the stats;
# we print out individual elements. # else, we print out individual elements.
summarize = (not sizes_same) or options.dim_is_summarized(stats.numel()) summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
if summarize: if summarize:
# print out percentiles. # Print out percentiles.
stats = stats.sort()[0] stats = stats.sort()[0]
num_percentiles = 10 num_percentiles = 10
size = stats.numel() size = stats.numel()
@ -117,14 +122,17 @@ def get_diagnostics_for_dim(
def print_diagnostics_for_dim( def print_diagnostics_for_dim(
name: str, dim: int, tensors: List[Tensor], options: TensorDiagnosticOptions name: str, dim: int, tensors: List[Tensor], options: TensorDiagnosticOptions
): ):
""" """This function prints diagnostics for a dimension of a tensor.
This function prints diagnostics for a dimension of a tensor.
Args: Args:
name: the tensor name name:
dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim The tensor name.
tensors: list of cached tensors to get the stats dim:
options: options object The dimension to analyze, with 0 <= dim < tensors[0].ndim.
tensors:
List of cached tensors to get the stats.
options:
Options object.
""" """
for stats_type in ["mean-abs", "pos-ratio"]: for stats_type in ["mean-abs", "pos-ratio"]:
@ -142,19 +150,20 @@ def print_diagnostics_for_dim(
class TensorDiagnostic(object): class TensorDiagnostic(object):
""" """This class is not directly used by the user, it is responsible for
This class is not directly used by the user, it is responsible for
collecting diagnostics for a single parameter tensor of a torch.nn.Module. collecting diagnostics for a single parameter tensor of a torch.nn.Module.
Attributes: Args:
opts: options object. opts:
name: tensor name. Options object.
saved_tensors: list of cached tensors. name:
The tensor name.
""" """
def __init__(self, opts: TensorDiagnosticOptions, name: str): def __init__(self, opts: TensorDiagnosticOptions, name: str):
self.name = name self.name = name
self.opts = opts self.opts = opts
# A list to cache the tensors.
self.saved_tensors = [] self.saved_tensors = []
def accumulate(self, x): def accumulate(self, x):
@ -195,7 +204,7 @@ class TensorDiagnostic(object):
return return
if self.saved_tensors[0].ndim == 0: if self.saved_tensors[0].ndim == 0:
# ensure there is at least one dim. # Ensure there is at least one dim.
self.saved_tensors = [x.unsqueeze(0) for x in self.saved_tensors] self.saved_tensors = [x.unsqueeze(0) for x in self.saved_tensors]
ndim = self.saved_tensors[0].ndim ndim = self.saved_tensors[0].ndim
@ -206,16 +215,16 @@ class TensorDiagnostic(object):
class ModelDiagnostic(object): class ModelDiagnostic(object):
""" """This class stores diagnostics for all tensors in the torch.nn.Module.
This class stores diagnostics for all tensors in the torch.nn.Module.
Attributes: Args:
diagnostics: a dictionary, whose keys are the tensors names and opts:
the values are corresponding TensorDiagnostic objects. Options object.
opts: options object.
""" """
def __init__(self, opts: TensorDiagnosticOptions): def __init__(self, opts: TensorDiagnosticOptions):
# In this dictionary, the keys are tensors names and the values
# are corresponding TensorDiagnostic objects.
self.diagnostics = dict() self.diagnostics = dict()
self.opts = opts self.opts = opts
@ -233,19 +242,20 @@ class ModelDiagnostic(object):
def attach_diagnostics( def attach_diagnostics(
model: nn.Module, opts: TensorDiagnosticOptions model: nn.Module, opts: TensorDiagnosticOptions
) -> ModelDiagnostic: ) -> ModelDiagnostic:
""" """Attach a ModelDiagnostic object to the model by
Attach a ModelDiagnostic object to the model by
1) registering forward hook and backward hook on each module, to accumulate 1) registering forward hook and backward hook on each module, to accumulate
its output tensors and gradient tensors, respectively; its output tensors and gradient tensors, respectively;
2) registering backward hook on each module parameter, to accumulate its 2) registering backward hook on each module parameter, to accumulate its
values and gradients. values and gradients.
Args: Args:
model: the model to be analyzed. model:
opts: options object. The model to be analyzed.
opts:
Options object.
Returns: Returns:
The ModelDiagnostic object attached to the model. The ModelDiagnostic object attached to the model.
""" """
ans = ModelDiagnostic(opts) ans = ModelDiagnostic(opts)
@ -253,10 +263,10 @@ def attach_diagnostics(
if name == "": if name == "":
name = "<top-level>" name = "<top-level>"
# setting model_diagnostic=ans and n=name below, instead of trying to # Setting model_diagnostic=ans and n=name below, instead of trying to
# capture the variables, ensures that we use the current values. # capture the variables, ensures that we use the current values.
# (matters for name, since the variable gets overwritten). # (matters for name, since the variable gets overwritten).
# these closures don't really capture by value, only by # These closures don't really capture by value, only by
# "the final value the variable got in the function" :-( # "the final value the variable got in the function" :-(
def forward_hook( def forward_hook(
_module, _input, _output, _model_diagnostic=ans, _name=name _module, _input, _output, _model_diagnostic=ans, _name=name