mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Update docs of arguments.
This commit is contained in:
parent
87b4619f12
commit
8be385f3bd
@ -6,12 +6,12 @@ from torch import Tensor, nn
|
|||||||
|
|
||||||
|
|
||||||
class TensorDiagnosticOptions(object):
|
class TensorDiagnosticOptions(object):
|
||||||
"""
|
"""Options object for tensor diagnostics:
|
||||||
Options object for tensor diagnostics:
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
memory_limit: the maximum number of bytes per tensor (limits how many
|
memory_limit:
|
||||||
copies of the tensor we cache).
|
The maximum number of bytes per tensor (limits how many copies
|
||||||
|
of the tensor we cache).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, memory_limit: int):
|
def __init__(self, memory_limit: int):
|
||||||
@ -24,22 +24,24 @@ class TensorDiagnosticOptions(object):
|
|||||||
def get_sum_abs_stats(
|
def get_sum_abs_stats(
|
||||||
x: Tensor, dim: int, stats_type: str
|
x: Tensor, dim: int, stats_type: str
|
||||||
) -> Tuple[Tensor, int]:
|
) -> Tuple[Tensor, int]:
|
||||||
"""
|
"""Returns the sum-of-absolute-value of this Tensor, for each index into
|
||||||
Returns the sum-of-absolute-value of this Tensor, for each index into
|
|
||||||
the specified axis/dim of the tensor.
|
the specified axis/dim of the tensor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: Tensor, tensor to be analyzed
|
x:
|
||||||
dim: dimension with 0 <= dim < x.ndim
|
Tensor, tensor to be analyzed
|
||||||
stats_type: either "mean-abs" in which case the stats represent the
|
dim:
|
||||||
mean absolute value, or "pos-ratio" in which case the stats represent
|
Dimension with 0 <= dim < x.ndim
|
||||||
the proportion of positive values (actually: the tensor is count of
|
stats_type:
|
||||||
positive values, count is the count of all values).
|
Either "mean-abs" in which case the stats represent the mean absolute
|
||||||
|
value, or "pos-ratio" in which case the stats represent the proportion
|
||||||
|
of positive values (actually: the tensor is count of positive values,
|
||||||
|
count is the count of all values).
|
||||||
|
|
||||||
Returns (sum_abs, count):
|
Returns:
|
||||||
where sum_abs is a Tensor of shape (x.shape[dim],), and the count
|
(sum_abs, count) where sum_abs is a Tensor of shape (x.shape[dim],),
|
||||||
is an integer saying how many items were counted in each element
|
and the count is an integer saying how many items were counted in
|
||||||
of sum_abs.
|
each element of sum_abs.
|
||||||
"""
|
"""
|
||||||
if stats_type == "mean-abs":
|
if stats_type == "mean-abs":
|
||||||
x = x.abs()
|
x = x.abs()
|
||||||
@ -63,21 +65,24 @@ def get_diagnostics_for_dim(
|
|||||||
sizes_same: bool,
|
sizes_same: bool,
|
||||||
stats_type: str,
|
stats_type: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""This function gets diagnostics for a dimension of a module.
|
||||||
This function gets diagnostics for a dimension of a module.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
|
dim:
|
||||||
tensors: list of cached tensors to get the stats
|
The dimension to analyze, with 0 <= dim < tensors[0].ndim
|
||||||
options: options object
|
tensors:
|
||||||
sizes_same: true if all the tensor sizes are the same on this dimension
|
List of cached tensors to get the stats
|
||||||
|
options:
|
||||||
|
Options object
|
||||||
|
sizes_same:
|
||||||
|
True if all the tensor sizes are the same on this dimension
|
||||||
stats_type: either "mean-abs" or "pos-ratio", dictates the type of
|
stats_type: either "mean-abs" or "pos-ratio", dictates the type of
|
||||||
stats we accumulate, mean-abs is mean absolute value, "pos-ratio" is
|
stats we accumulate, mean-abs is mean absolute value, "pos-ratio" is
|
||||||
proportion of positive to nonnegative values.
|
proportion of positive to nonnegative values.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Diagnostic as a string, either percentiles or the actual values,
|
Diagnostic as a string, either percentiles or the actual values,
|
||||||
see the code.
|
see the code.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# stats_and_counts is a list of pair (Tensor, int)
|
# stats_and_counts is a list of pair (Tensor, int)
|
||||||
@ -92,11 +97,11 @@ def get_diagnostics_for_dim(
|
|||||||
stats = [x[0] / x[1] for x in stats_and_counts]
|
stats = [x[0] / x[1] for x in stats_and_counts]
|
||||||
stats = torch.cat(stats, dim=0)
|
stats = torch.cat(stats, dim=0)
|
||||||
|
|
||||||
# if `summarize` we print percentiles of the stats; else,
|
# If `summarize` we print percentiles of the stats;
|
||||||
# we print out individual elements.
|
# else, we print out individual elements.
|
||||||
summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
|
summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
|
||||||
if summarize:
|
if summarize:
|
||||||
# print out percentiles.
|
# Print out percentiles.
|
||||||
stats = stats.sort()[0]
|
stats = stats.sort()[0]
|
||||||
num_percentiles = 10
|
num_percentiles = 10
|
||||||
size = stats.numel()
|
size = stats.numel()
|
||||||
@ -117,14 +122,17 @@ def get_diagnostics_for_dim(
|
|||||||
def print_diagnostics_for_dim(
|
def print_diagnostics_for_dim(
|
||||||
name: str, dim: int, tensors: List[Tensor], options: TensorDiagnosticOptions
|
name: str, dim: int, tensors: List[Tensor], options: TensorDiagnosticOptions
|
||||||
):
|
):
|
||||||
"""
|
"""This function prints diagnostics for a dimension of a tensor.
|
||||||
This function prints diagnostics for a dimension of a tensor.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: the tensor name
|
name:
|
||||||
dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
|
The tensor name.
|
||||||
tensors: list of cached tensors to get the stats
|
dim:
|
||||||
options: options object
|
The dimension to analyze, with 0 <= dim < tensors[0].ndim.
|
||||||
|
tensors:
|
||||||
|
List of cached tensors to get the stats.
|
||||||
|
options:
|
||||||
|
Options object.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for stats_type in ["mean-abs", "pos-ratio"]:
|
for stats_type in ["mean-abs", "pos-ratio"]:
|
||||||
@ -142,19 +150,20 @@ def print_diagnostics_for_dim(
|
|||||||
|
|
||||||
|
|
||||||
class TensorDiagnostic(object):
|
class TensorDiagnostic(object):
|
||||||
"""
|
"""This class is not directly used by the user, it is responsible for
|
||||||
This class is not directly used by the user, it is responsible for
|
|
||||||
collecting diagnostics for a single parameter tensor of a torch.nn.Module.
|
collecting diagnostics for a single parameter tensor of a torch.nn.Module.
|
||||||
|
|
||||||
Attributes:
|
Args:
|
||||||
opts: options object.
|
opts:
|
||||||
name: tensor name.
|
Options object.
|
||||||
saved_tensors: list of cached tensors.
|
name:
|
||||||
|
The tensor name.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, opts: TensorDiagnosticOptions, name: str):
|
def __init__(self, opts: TensorDiagnosticOptions, name: str):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.opts = opts
|
self.opts = opts
|
||||||
|
# A list to cache the tensors.
|
||||||
self.saved_tensors = []
|
self.saved_tensors = []
|
||||||
|
|
||||||
def accumulate(self, x):
|
def accumulate(self, x):
|
||||||
@ -195,7 +204,7 @@ class TensorDiagnostic(object):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if self.saved_tensors[0].ndim == 0:
|
if self.saved_tensors[0].ndim == 0:
|
||||||
# ensure there is at least one dim.
|
# Ensure there is at least one dim.
|
||||||
self.saved_tensors = [x.unsqueeze(0) for x in self.saved_tensors]
|
self.saved_tensors = [x.unsqueeze(0) for x in self.saved_tensors]
|
||||||
|
|
||||||
ndim = self.saved_tensors[0].ndim
|
ndim = self.saved_tensors[0].ndim
|
||||||
@ -206,16 +215,16 @@ class TensorDiagnostic(object):
|
|||||||
|
|
||||||
|
|
||||||
class ModelDiagnostic(object):
|
class ModelDiagnostic(object):
|
||||||
"""
|
"""This class stores diagnostics for all tensors in the torch.nn.Module.
|
||||||
This class stores diagnostics for all tensors in the torch.nn.Module.
|
|
||||||
|
|
||||||
Attributes:
|
Args:
|
||||||
diagnostics: a dictionary, whose keys are the tensors names and
|
opts:
|
||||||
the values are corresponding TensorDiagnostic objects.
|
Options object.
|
||||||
opts: options object.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, opts: TensorDiagnosticOptions):
|
def __init__(self, opts: TensorDiagnosticOptions):
|
||||||
|
# In this dictionary, the keys are tensors names and the values
|
||||||
|
# are corresponding TensorDiagnostic objects.
|
||||||
self.diagnostics = dict()
|
self.diagnostics = dict()
|
||||||
self.opts = opts
|
self.opts = opts
|
||||||
|
|
||||||
@ -233,19 +242,20 @@ class ModelDiagnostic(object):
|
|||||||
def attach_diagnostics(
|
def attach_diagnostics(
|
||||||
model: nn.Module, opts: TensorDiagnosticOptions
|
model: nn.Module, opts: TensorDiagnosticOptions
|
||||||
) -> ModelDiagnostic:
|
) -> ModelDiagnostic:
|
||||||
"""
|
"""Attach a ModelDiagnostic object to the model by
|
||||||
Attach a ModelDiagnostic object to the model by
|
|
||||||
1) registering forward hook and backward hook on each module, to accumulate
|
1) registering forward hook and backward hook on each module, to accumulate
|
||||||
its output tensors and gradient tensors, respectively;
|
its output tensors and gradient tensors, respectively;
|
||||||
2) registering backward hook on each module parameter, to accumulate its
|
2) registering backward hook on each module parameter, to accumulate its
|
||||||
values and gradients.
|
values and gradients.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: the model to be analyzed.
|
model:
|
||||||
opts: options object.
|
the model to be analyzed.
|
||||||
|
opts:
|
||||||
|
Options object.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The ModelDiagnostic object attached to the model.
|
The ModelDiagnostic object attached to the model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ans = ModelDiagnostic(opts)
|
ans = ModelDiagnostic(opts)
|
||||||
@ -253,10 +263,10 @@ def attach_diagnostics(
|
|||||||
if name == "":
|
if name == "":
|
||||||
name = "<top-level>"
|
name = "<top-level>"
|
||||||
|
|
||||||
# setting model_diagnostic=ans and n=name below, instead of trying to
|
# Setting model_diagnostic=ans and n=name below, instead of trying to
|
||||||
# capture the variables, ensures that we use the current values.
|
# capture the variables, ensures that we use the current values.
|
||||||
# (matters for name, since the variable gets overwritten).
|
# (matters for name, since the variable gets overwritten).
|
||||||
# these closures don't really capture by value, only by
|
# These closures don't really capture by value, only by
|
||||||
# "the final value the variable got in the function" :-(
|
# "the final value the variable got in the function" :-(
|
||||||
def forward_hook(
|
def forward_hook(
|
||||||
_module, _input, _output, _model_diagnostic=ans, _name=name
|
_module, _input, _output, _model_diagnostic=ans, _name=name
|
||||||
|
Loading…
x
Reference in New Issue
Block a user