Update docs of arguments, and remove stats_types() function in TensorDiagnosticOptions object.

yaozengwei 2022-03-03 15:35:12 +08:00
parent 828a23daad
commit 87b4619f12
2 changed files with 149 additions and 115 deletions

View File

@@ -521,7 +521,6 @@ def train_one_epoch(
if params.print_diagnostics and batch_idx == 5:
return
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, "
@@ -631,10 +630,11 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args)
if params.print_diagnostics:
-opts = diagnostics.TensorDiagnosticOptions(2**22) # allow 4 megabytes per sub-module
+opts = diagnostics.TensorDiagnosticOptions(
+2 ** 22
+) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
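For orientation, here is a rough usage sketch (not part of this diff) of the diagnostics API wired up above; `model`, `train_dl` and `compute_loss` are hypothetical placeholders, and the real script simply returns from train_one_epoch after a few batches when params.print_diagnostics is set:

opts = diagnostics.TensorDiagnosticOptions(2 ** 22)  # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
for batch_idx, batch in enumerate(train_dl):
    loss = compute_loss(model, batch)  # hypothetical loss helper
    loss.backward()                    # backward hooks record gradient stats too
    if batch_idx == 5:                 # a few batches are enough for diagnostics
        break
diagnostic.print_diagnostics()         # per-module, per-dimension statistics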

View File

@@ -1,9 +1,8 @@
-import torch
-from torch import Tensor
-from torch import nn
-import math
import random
-from typing import Tuple, List
+from typing import List, Tuple
+import torch
+from torch import Tensor, nn
class TensorDiagnosticOptions(object):
@@ -11,40 +10,33 @@ class TensorDiagnosticOptions(object):
Options object for tensor diagnostics:
Args:
-memory_limit: the maximum number of bytes per tensor (limits how many copies
-of the tensor we cache).
+memory_limit: the maximum number of bytes per tensor (limits how many
+copies of the tensor we cache).
"""
-def __init__(self, memory_limit: int,
-print_pos_ratio: bool = True):
+def __init__(self, memory_limit: int):
self.memory_limit = memory_limit
-self.print_pos_ratio = print_pos_ratio
def dim_is_summarized(self, size: int):
return size > 10 and size != 31
-def stats_types(self):
-if self.print_pos_ratio:
-return ["mean-abs", "pos-ratio"]
-else:
-return ["mean-abs"]
-def get_sum_abs_stats(x: Tensor, dim: int,
-stats_type: str) -> Tuple[Tensor, int]:
+def get_sum_abs_stats(
+x: Tensor, dim: int, stats_type: str
+) -> Tuple[Tensor, int]:
"""
-Returns the sum-of-absolute-value of this Tensor, for each
-index into the specified axis/dim of the tensor.
+Returns the sum-of-absolute-value of this Tensor, for each index into
+the specified axis/dim of the tensor.
Args:
x: Tensor, tensor to be analyzed
dim: dimension with 0 <= dim < x.ndim
stats_type: either "mean-abs" in which case the stats represent the
-mean absolute value, or "pos-ratio" in which case the
-stats represent the proportion of positive values (actually:
-the tensor is count of positive values, count is the count of
-all values).
-Returns (sum_abs, count)
+mean absolute value, or "pos-ratio" in which case the stats represent
+the proportion of positive values (actually: the tensor is count of
+positive values, count is the count of all values).
+Returns (sum_abs, count):
where sum_abs is a Tensor of shape (x.shape[dim],), and the count
is an integer saying how many items were counted in each element
of sum_abs.
@@ -54,30 +46,40 @@ def get_sum_abs_stats(x: Tensor, dim: int,
else:
assert stats_type == "pos-ratio"
x = (x > 0).to(dtype=torch.float)
orig_numel = x.numel()
sum_dims = [d for d in range(x.ndim) if d != dim]
x = torch.sum(x, dim=sum_dims)
count = orig_numel // x.numel()
x = x.flatten()
return x, count
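As a small illustration (not part of the diff) of what get_sum_abs_stats computes, take a 2x3 tensor analyzed along dim=1 with the "mean-abs" stats type:

x = torch.tensor([[1.0, -2.0, 3.0], [-4.0, 5.0, -6.0]])
sum_abs, count = get_sum_abs_stats(x, dim=1, stats_type="mean-abs")
# sum_abs is tensor([5., 7., 9.]) of shape (x.shape[1],) = (3,),
# count is 6 // 3 = 2, so the mean absolute value per index is sum_abs / count.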
-def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
+def get_diagnostics_for_dim(
+dim: int,
+tensors: List[Tensor],
options: TensorDiagnosticOptions,
sizes_same: bool,
-stats_type: str):
+stats_type: str,
+) -> str:
"""
This function gets diagnostics for a dimension of a module.
Args:
dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
tensors: list of cached tensors to get the stats
options: options object
sizes_same: true if all the tensor sizes are the same on this dimension
-stats_type: either "mean-abs" or "pos-ratio", dictates the type of stats
-we accumulate, mean-abs is mean absolute value, "pos-ratio"
-is proportion of positive to nonnegative values.
+stats_type: either "mean-abs" or "pos-ratio", dictates the type of
+stats we accumulate, mean-abs is mean absolute value, "pos-ratio" is
+proportion of positive to nonnegative values.
Returns:
Diagnostic as a string, either percentiles or the actual values,
see the code.
"""
# stats_and_counts is a list of pair (Tensor, int)
stats_and_counts = [get_sum_abs_stats(x, dim, stats_type) for x in tensors]
stats = [x[0] for x in stats_and_counts]
@@ -89,6 +91,7 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
else:
stats = [x[0] / x[1] for x in stats_and_counts]
stats = torch.cat(stats, dim=0)
# if `summarize` we print percentiles of the stats; else,
# we print out individual elements.
summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
@@ -101,89 +104,117 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
for i in range(num_percentiles + 1):
index = (i * (size - 1)) // num_percentiles
percentiles.append(stats[index].item())
-percentiles = [ '%.2g' % x for x in percentiles ]
-percentiles = ' '.join(percentiles)
-return f'percentiles: [{percentiles}]'
+percentiles = ["%.2g" % x for x in percentiles]
+percentiles = " ".join(percentiles)
+return f"percentiles: [{percentiles}]"
else:
stats = stats.tolist()
-stats = [ '%.2g' % x for x in stats ]
-stats = '[' + ' '.join(stats) + ']'
+stats = ["%.2g" % x for x in stats]
+stats = "[" + " ".join(stats) + "]"
return stats
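To make the summarized branch concrete, here is a sketch (not in the diff) of how the percentile indices above are sampled; num_percentiles = 10 is assumed purely for illustration:

num_percentiles = 10
size = 101  # stats.numel() for a dimension large enough to be summarized
indices = [(i * (size - 1)) // num_percentiles for i in range(num_percentiles + 1)]
# indices == [0, 10, 20, ..., 100]; the stats values at these evenly spaced
# positions are formatted with "%.2g" and joined into the "percentiles: [...]" string.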
+def print_diagnostics_for_dim(
+name: str, dim: int, tensors: List[Tensor], options: TensorDiagnosticOptions
+):
+"""
+This function prints diagnostics for a dimension of a tensor.
-def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor],
-options: TensorDiagnosticOptions):
Args:
name: the tensor name
dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
tensors: list of cached tensors to get the stats
options: options object
"""
-for stats_type in options.stats_types():
+for stats_type in ["mean-abs", "pos-ratio"]:
# stats_type will be "mean-abs" or "pos-ratio".
sizes = [x.shape[dim] for x in tensors]
sizes_same = all([x == sizes[0] for x in sizes])
-s = get_diagnostics_for_dim(dim, tensors,
-options, sizes_same,
-stats_type)
+s = get_diagnostics_for_dim(
+dim, tensors, options, sizes_same, stats_type
+)
min_size = min(sizes)
max_size = max(sizes)
size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}"
# stats_type will be "mean-abs" or "pos-ratio".
print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}")
class TensorDiagnostic(object):
"""
-This class is not directly used by the user, it is responsible for collecting
-diagnostics for a single parameter tensor of a torch.Module.
+This class is not directly used by the user, it is responsible for
+collecting diagnostics for a single parameter tensor of a torch.nn.Module.
Attributes:
opts: options object.
name: tensor name.
saved_tensors: list of cached tensors.
"""
-def __init__(self,
-opts: TensorDiagnosticOptions,
-name: str):
+def __init__(self, opts: TensorDiagnosticOptions, name: str):
self.name = name
self.opts = opts
self.saved_tensors = []
def accumulate(self, x):
"""Accumulate tensors."""
if isinstance(x, Tuple):
x = x[0]
if not isinstance(x, Tensor):
return
-if x.device == torch.device('cpu'):
+if x.device == torch.device("cpu"):
x = x.detach().clone()
else:
-x = x.detach().to('cpu', non_blocking=True)
+x = x.detach().to("cpu", non_blocking=True)
self.saved_tensors.append(x)
-l = len(self.saved_tensors)
-if l & (l - 1) == 0: # power of 2..
+num = len(self.saved_tensors)
+if num & (num - 1) == 0: # power of 2..
self._limit_memory()
def _limit_memory(self):
"""Only keep the newly cached tensors to limit memory."""
if len(self.saved_tensors) > 1024:
self.saved_tensors = self.saved_tensors[-1024:]
return
tot_mem = 0.0
for i in reversed(range(len(self.saved_tensors))):
-tot_mem += self.saved_tensors[i].numel() * self.saved_tensors[i].element_size()
+tot_mem += (
+self.saved_tensors[i].numel()
+* self.saved_tensors[i].element_size()
+)
if tot_mem > self.opts.memory_limit:
self.saved_tensors = self.saved_tensors[i:]
return
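A brief aside (not part of the diff) on the check in accumulate() above: for positive n, n & (n - 1) == 0 holds exactly when n is a power of two, so _limit_memory() runs after the 1st, 2nd, 4th, 8th, ... cached tensor rather than on every call:

for n in [1, 2, 3, 4, 6, 8]:
    print(n, n & (n - 1) == 0)  # True, True, False, True, False, True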
def print_diagnostics(self):
"""Print diagnostics for each dimension of the tensor."""
if len(self.saved_tensors) == 0:
print("{name}: no stats".format(name=self.name))
return
if self.saved_tensors[0].ndim == 0:
# ensure there is at least one dim.
self.saved_tensors = [x.unsqueeze(0) for x in self.saved_tensors]
ndim = self.saved_tensors[0].ndim
for dim in range(ndim):
-print_diagnostics_for_dim(self.name, dim,
-self.saved_tensors,
-self.opts)
+print_diagnostics_for_dim(
+self.name, dim, self.saved_tensors, self.opts
+)
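A minimal sketch (not in the diff) of exercising TensorDiagnostic directly, along the lines of _test_tensor_diagnostic() further down; the shapes and the name "foo" are arbitrary:

opts = TensorDiagnosticOptions(2 ** 20)
diag = TensorDiagnostic(opts, "foo")
for _ in range(10):
    diag.accumulate(torch.randn(50, 100) * 10.0)  # cached on CPU by accumulate()
diag.print_diagnostics()  # prints lines like "module=foo, dim=0, size=50, mean-abs ..."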
class ModelDiagnostic(object):
"""
This class stores diagnostics for all tensors in the torch.nn.Module.
Attributes:
diagnostics: a dictionary, whose keys are the tensors names and
the values are corresponding TensorDiagnostic objects.
opts: options object.
"""
def __init__(self, opts: TensorDiagnosticOptions):
self.diagnostics = dict()
self.opts = opts
@@ -194,35 +225,51 @@ class ModelDiagnostic(object):
return self.diagnostics[name]
def print_diagnostics(self):
"""Print diagnostics for each tensor."""
for k in sorted(self.diagnostics.keys()):
self.diagnostics[k].print_diagnostics()
+def attach_diagnostics(
+model: nn.Module, opts: TensorDiagnosticOptions
+) -> ModelDiagnostic:
+"""
+Attach a ModelDiagnostic object to the model by
+1) registering forward hook and backward hook on each module, to accumulate
+its output tensors and gradient tensors, respectively;
+2) registering backward hook on each module parameter, to accumulate its
+values and gradients.
+Args:
+model: the model to be analyzed.
+opts: options object.
+Returns:
+The ModelDiagnostic object attached to the model.
+"""
-def attach_diagnostics(model: nn.Module,
-opts: TensorDiagnosticOptions) -> ModelDiagnostic:
ans = ModelDiagnostic(opts)
for name, module in model.named_modules():
-if name == '':
+if name == "":
name = "<top-level>"
forward_diagnostic = TensorDiagnostic(opts, name + ".output")
backward_diagnostic = TensorDiagnostic(opts, name + ".grad")
-# setting model_diagnostic=ans and n=name below, instead of trying to capture the variables,
-# ensures that we use the current values. (matters for name, since
-# the variable gets overwritten). these closures don't really capture
-# by value, only by "the final value the variable got in the function" :-(
-def forward_hook(_module, _input, _output,
-_model_diagnostic=ans, _name=name):
+# setting model_diagnostic=ans and n=name below, instead of trying to
+# capture the variables, ensures that we use the current values.
+# (matters for name, since the variable gets overwritten).
+# these closures don't really capture by value, only by
+# "the final value the variable got in the function" :-(
+def forward_hook(
+_module, _input, _output, _model_diagnostic=ans, _name=name
+):
if isinstance(_output, Tensor):
_model_diagnostic[f"{_name}.output"].accumulate(_output)
elif isinstance(_output, tuple):
for i, o in enumerate(_output):
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(o)
-def backward_hook(_module, _input, _output,
-_model_diagnostic=ans, _name=name):
+def backward_hook(
+_module, _input, _output, _model_diagnostic=ans, _name=name
+):
if isinstance(_output, Tensor):
_model_diagnostic[f"{_name}.grad"].accumulate(_output)
elif isinstance(_output, tuple):
@@ -234,20 +281,19 @@ def attach_diagnostics(model: nn.Module,
for name, parameter in model.named_parameters():
-def param_backward_hook(grad,
-_parameter=parameter,
-_model_diagnostic=ans,
-_name=name):
+def param_backward_hook(
+grad, _parameter=parameter, _model_diagnostic=ans, _name=name
+):
_model_diagnostic[f"{_name}.param_value"].accumulate(_parameter)
_model_diagnostic[f"{_name}.param_grad"].accumulate(grad)
parameter.register_hook(param_backward_hook)
return ans
def _test_tensor_diagnostic():
-opts = TensorDiagnosticOptions(2**20, True)
+opts = TensorDiagnosticOptions(2 ** 20)
diagnostic = TensorDiagnostic(opts, "foo")
@@ -268,17 +314,5 @@ def _test_tensor_diagnostic():
diagnostic.print_diagnostics()
-if __name__ == '__main__':
+if __name__ == "__main__":
_test_tensor_diagnostic()
-def _test_func():
-ans = []
-for i in range(10):
-x = list()
-x.append(i)
-def func():
-return x
-ans.append(func)
-return ans
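The closure comment inside attach_diagnostics and the _test_func removed above both refer to Python's late binding of closed-over variables; a tiny standalone sketch (not part of the diff) of the pitfall and the default-argument workaround used by the hooks:

hooks_late = [lambda: name for name in ["a", "b"]]               # both end up seeing "b"
hooks_bound = [lambda _name=name: _name for name in ["a", "b"]]  # defaults freeze each value
print([h() for h in hooks_late])   # ['b', 'b']
print([h() for h in hooks_bound])  # ['a', 'b']

This is why forward_hook, backward_hook and param_backward_hook pass _model_diagnostic=ans and _name=name as default arguments instead of relying on the enclosing scope.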