Update docs of arguments, and remove stats_types() function in TensorDiagnosticOptions object.

This commit is contained in:
yaozengwei 2022-03-03 15:35:12 +08:00
parent 828a23daad
commit 87b4619f12
2 changed files with 149 additions and 115 deletions

View File

@ -521,7 +521,6 @@ def train_one_epoch(
if params.print_diagnostics and batch_idx == 5: if params.print_diagnostics and batch_idx == 5:
return return
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
logging.info( logging.info(
f"Epoch {params.cur_epoch}, " f"Epoch {params.cur_epoch}, "
@ -631,10 +630,11 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
if params.print_diagnostics: if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(2**22) # allow 4 megabytes per sub-module opts = diagnostics.TensorDiagnosticOptions(
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts) diagnostic = diagnostics.attach_diagnostics(model, opts)
train_cuts = librispeech.train_clean_100_cuts() train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri: if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts() train_cuts += librispeech.train_clean_360_cuts()

View File

@ -1,94 +1,97 @@
import torch
from torch import Tensor
from torch import nn
import math
import random import random
from typing import Tuple, List from typing import List, Tuple
import torch
from torch import Tensor, nn
class TensorDiagnosticOptions(object): class TensorDiagnosticOptions(object):
""" """
Options object for tensor diagnostics: Options object for tensor diagnostics:
Args: Args:
memory_limit: the maximum number of bytes per tensor (limits how many copies memory_limit: the maximum number of bytes per tensor (limits how many
of the tensor we cache). copies of the tensor we cache).
""" """
def __init__(self, memory_limit: int,
print_pos_ratio: bool = True): def __init__(self, memory_limit: int):
self.memory_limit = memory_limit self.memory_limit = memory_limit
self.print_pos_ratio = print_pos_ratio
def dim_is_summarized(self, size: int): def dim_is_summarized(self, size: int):
return size > 10 and size != 31 return size > 10 and size != 31
def stats_types(self):
if self.print_pos_ratio:
return ["mean-abs", "pos-ratio"]
else:
return ["mean-abs"]
def get_sum_abs_stats(
x: Tensor, dim: int, stats_type: str
def get_sum_abs_stats(x: Tensor, dim: int, ) -> Tuple[Tensor, int]:
stats_type: str) -> Tuple[Tensor, int]:
""" """
Returns the sum-of-absolute-value of this Tensor, for each Returns the sum-of-absolute-value of this Tensor, for each index into
index into the specified axis/dim of the tensor. the specified axis/dim of the tensor.
Args: Args:
x: Tensor, tensor to be analyzed x: Tensor, tensor to be analyzed
dim: dimension with 0 <= dim < x.ndim dim: dimension with 0 <= dim < x.ndim
stats_type: either "mean-abs" in which case the stats represent the stats_type: either "mean-abs" in which case the stats represent the
mean absolute value, or "pos-ratio" in which case the mean absolute value, or "pos-ratio" in which case the stats represent
stats represent the proportion of positive values (actually: the proportion of positive values (actually: the tensor is count of
the tensor is count of positive values, count is the count of positive values, count is the count of all values).
all values).
Returns (sum_abs, count) Returns (sum_abs, count):
where sum_abs is a Tensor of shape (x.shape[dim],), and the count where sum_abs is a Tensor of shape (x.shape[dim],), and the count
is an integer saying how many items were counted in each element is an integer saying how many items were counted in each element
of sum_abs. of sum_abs.
""" """
if stats_type == "mean-abs": if stats_type == "mean-abs":
x = x.abs() x = x.abs()
else: else:
assert stats_type == "pos-ratio" assert stats_type == "pos-ratio"
x = (x > 0).to(dtype=torch.float) x = (x > 0).to(dtype=torch.float)
orig_numel = x.numel() orig_numel = x.numel()
sum_dims = [ d for d in range(x.ndim) if d != dim ] sum_dims = [d for d in range(x.ndim) if d != dim]
x = torch.sum(x, dim=sum_dims) x = torch.sum(x, dim=sum_dims)
count = orig_numel // x.numel() count = orig_numel // x.numel()
x = x.flatten() x = x.flatten()
return x, count return x, count
def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
options: TensorDiagnosticOptions, def get_diagnostics_for_dim(
sizes_same: bool, dim: int,
stats_type: str): tensors: List[Tensor],
options: TensorDiagnosticOptions,
sizes_same: bool,
stats_type: str,
) -> str:
""" """
This function gets diagnostics for a dimension of a module. This function gets diagnostics for a dimension of a module.
Args: Args:
dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
options: options object tensors: list of cached tensors to get the stats
sizes_same: true if all the tensor sizes are the same on this dimension options: options object
stats_type: either "mean-abs" or "pos-ratio", dictates the type of stats sizes_same: true if all the tensor sizes are the same on this dimension
we accumulate, mean-abs is mean absolute value, "pos-ratio" stats_type: either "mean-abs" or "pos-ratio", dictates the type of
is proportion of positive to nonnegative values. stats we accumulate, mean-abs is mean absolute value, "pos-ratio" is
proportion of positive to nonnegative values.
Returns: Returns:
Diagnostic as a string, either percentiles or the actual values, Diagnostic as a string, either percentiles or the actual values,
see the code. see the code.
""" """
# stats_and_counts is a list of pair (Tensor, int) # stats_and_counts is a list of pair (Tensor, int)
stats_and_counts = [ get_sum_abs_stats(x, dim, stats_type) for x in tensors ] stats_and_counts = [get_sum_abs_stats(x, dim, stats_type) for x in tensors]
stats = [ x[0] for x in stats_and_counts ] stats = [x[0] for x in stats_and_counts]
counts = [ x[1] for x in stats_and_counts ] counts = [x[1] for x in stats_and_counts]
if sizes_same: if sizes_same:
stats = torch.stack(stats).sum(dim=0) stats = torch.stack(stats).sum(dim=0)
count = sum(counts) count = sum(counts)
stats = stats / count stats = stats / count
else: else:
stats = [ x[0] / x[1] for x in stats_and_counts ] stats = [x[0] / x[1] for x in stats_and_counts]
stats = torch.cat(stats, dim=0) stats = torch.cat(stats, dim=0)
# if `summarize` we print percentiles of the stats; else, # if `summarize` we print percentiles of the stats; else,
# we print out individual elements. # we print out individual elements.
summarize = (not sizes_same) or options.dim_is_summarized(stats.numel()) summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
@ -101,89 +104,117 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
for i in range(num_percentiles + 1): for i in range(num_percentiles + 1):
index = (i * (size - 1)) // num_percentiles index = (i * (size - 1)) // num_percentiles
percentiles.append(stats[index].item()) percentiles.append(stats[index].item())
percentiles = [ '%.2g' % x for x in percentiles ] percentiles = ["%.2g" % x for x in percentiles]
percentiles = ' '.join(percentiles) percentiles = " ".join(percentiles)
return f'percentiles: [{percentiles}]' return f"percentiles: [{percentiles}]"
else: else:
stats = stats.tolist() stats = stats.tolist()
stats = [ '%.2g' % x for x in stats ] stats = ["%.2g" % x for x in stats]
stats = '[' + ' '.join(stats) + ']' stats = "[" + " ".join(stats) + "]"
return stats return stats
def print_diagnostics_for_dim(
name: str, dim: int, tensors: List[Tensor], options: TensorDiagnosticOptions
):
"""
This function prints diagnostics for a dimension of a tensor.
def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor], Args:
options: TensorDiagnosticOptions): name: the tensor name
dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
tensors: list of cached tensors to get the stats
options: options object
"""
for stats_type in options.stats_types(): for stats_type in ["mean-abs", "pos-ratio"]:
# stats_type will be "mean-abs" or "pos-ratio". # stats_type will be "mean-abs" or "pos-ratio".
sizes = [ x.shape[dim] for x in tensors ] sizes = [x.shape[dim] for x in tensors]
sizes_same = all([ x == sizes[0] for x in sizes ]) sizes_same = all([x == sizes[0] for x in sizes])
s = get_diagnostics_for_dim(dim, tensors, s = get_diagnostics_for_dim(
options, sizes_same, dim, tensors, options, sizes_same, stats_type
stats_type) )
min_size = min(sizes) min_size = min(sizes)
max_size = max(sizes) max_size = max(sizes)
size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}" size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}"
# stats_type will be "mean-abs" or "pos-ratio".
print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}") print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}")
class TensorDiagnostic(object): class TensorDiagnostic(object):
""" """
This class is not directly used by the user, it is responsible for collecting This class is not directly used by the user, it is responsible for
diagnostics for a single parameter tensor of a torch.Module. collecting diagnostics for a single parameter tensor of a torch.nn.Module.
Attributes:
opts: options object.
name: tensor name.
saved_tensors: list of cached tensors.
""" """
def __init__(self,
opts: TensorDiagnosticOptions, def __init__(self, opts: TensorDiagnosticOptions, name: str):
name: str):
self.name = name self.name = name
self.opts = opts self.opts = opts
self.saved_tensors = [] self.saved_tensors = []
def accumulate(self, x): def accumulate(self, x):
"""Accumulate tensors."""
if isinstance(x, Tuple): if isinstance(x, Tuple):
x = x[0] x = x[0]
if not isinstance(x, Tensor): if not isinstance(x, Tensor):
return return
if x.device == torch.device('cpu'): if x.device == torch.device("cpu"):
x = x.detach().clone() x = x.detach().clone()
else: else:
x = x.detach().to('cpu', non_blocking=True) x = x.detach().to("cpu", non_blocking=True)
self.saved_tensors.append(x) self.saved_tensors.append(x)
l = len(self.saved_tensors) num = len(self.saved_tensors)
if l & (l - 1) == 0: # power of 2.. if num & (num - 1) == 0: # power of 2..
self._limit_memory() self._limit_memory()
def _limit_memory(self): def _limit_memory(self):
"""Only keep the newly cached tensors to limit memory."""
if len(self.saved_tensors) > 1024: if len(self.saved_tensors) > 1024:
self.saved_tensors = self.saved_tensors[-1024:] self.saved_tensors = self.saved_tensors[-1024:]
return return
tot_mem = 0.0 tot_mem = 0.0
for i in reversed(range(len(self.saved_tensors))): for i in reversed(range(len(self.saved_tensors))):
tot_mem += self.saved_tensors[i].numel() * self.saved_tensors[i].element_size() tot_mem += (
self.saved_tensors[i].numel()
* self.saved_tensors[i].element_size()
)
if tot_mem > self.opts.memory_limit: if tot_mem > self.opts.memory_limit:
self.saved_tensors = self.saved_tensors[i:] self.saved_tensors = self.saved_tensors[i:]
return return
def print_diagnostics(self): def print_diagnostics(self):
"""Print diagnostics for each dimension of the tensor."""
if len(self.saved_tensors) == 0: if len(self.saved_tensors) == 0:
print("{name}: no stats".format(name=self.name)) print("{name}: no stats".format(name=self.name))
return return
if self.saved_tensors[0].ndim == 0: if self.saved_tensors[0].ndim == 0:
# ensure there is at least one dim. # ensure there is at least one dim.
self.saved_tensors = [ x.unsqueeze(0) for x in self.saved_tensors ] self.saved_tensors = [x.unsqueeze(0) for x in self.saved_tensors]
ndim = self.saved_tensors[0].ndim ndim = self.saved_tensors[0].ndim
for dim in range(ndim): for dim in range(ndim):
print_diagnostics_for_dim(self.name, dim, print_diagnostics_for_dim(
self.saved_tensors, self.name, dim, self.saved_tensors, self.opts
self.opts) )
class ModelDiagnostic(object): class ModelDiagnostic(object):
"""
This class stores diagnostics for all tensors in the torch.nn.Module.
Attributes:
diagnostics: a dictionary, whose keys are the tensors names and
the values are corresponding TensorDiagnostic objects.
opts: options object.
"""
def __init__(self, opts: TensorDiagnosticOptions): def __init__(self, opts: TensorDiagnosticOptions):
self.diagnostics = dict() self.diagnostics = dict()
self.opts = opts self.opts = opts
@ -194,35 +225,51 @@ class ModelDiagnostic(object):
return self.diagnostics[name] return self.diagnostics[name]
def print_diagnostics(self): def print_diagnostics(self):
"""Print diagnostics for each tensor."""
for k in sorted(self.diagnostics.keys()): for k in sorted(self.diagnostics.keys()):
self.diagnostics[k].print_diagnostics() self.diagnostics[k].print_diagnostics()
def attach_diagnostics(
model: nn.Module, opts: TensorDiagnosticOptions
) -> ModelDiagnostic:
"""
Attach a ModelDiagnostic object to the model by
1) registering forward hook and backward hook on each module, to accumulate
its output tensors and gradient tensors, respectively;
2) registering backward hook on each module parameter, to accumulate its
values and gradients.
Args:
model: the model to be analyzed.
opts: options object.
Returns:
The ModelDiagnostic object attached to the model.
"""
def attach_diagnostics(model: nn.Module,
opts: TensorDiagnosticOptions) -> ModelDiagnostic:
ans = ModelDiagnostic(opts) ans = ModelDiagnostic(opts)
for name, module in model.named_modules(): for name, module in model.named_modules():
if name == '': if name == "":
name = "<top-level>" name = "<top-level>"
forward_diagnostic = TensorDiagnostic(opts, name + ".output")
backward_diagnostic = TensorDiagnostic(opts, name + ".grad")
# setting model_diagnostic=ans and n=name below, instead of trying to
# setting model_diagnostic=ans and n=name below, instead of trying to capture the variables, # capture the variables, ensures that we use the current values.
# ensures that we use the current values. (matters for name, since # (matters for name, since the variable gets overwritten).
# the variable gets overwritten). these closures don't really capture # these closures don't really capture by value, only by
# by value, only by "the final value the variable got in the function" :-( # "the final value the variable got in the function" :-(
def forward_hook(_module, _input, _output, def forward_hook(
_model_diagnostic=ans, _name=name): _module, _input, _output, _model_diagnostic=ans, _name=name
):
if isinstance(_output, Tensor): if isinstance(_output, Tensor):
_model_diagnostic[f"{_name}.output"].accumulate(_output) _model_diagnostic[f"{_name}.output"].accumulate(_output)
elif isinstance(_output, tuple): elif isinstance(_output, tuple):
for i, o in enumerate(_output): for i, o in enumerate(_output):
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(o) _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o)
def backward_hook(_module, _input, _output, def backward_hook(
_model_diagnostic=ans, _name=name): _module, _input, _output, _model_diagnostic=ans, _name=name
):
if isinstance(_output, Tensor): if isinstance(_output, Tensor):
_model_diagnostic[f"{_name}.grad"].accumulate(_output) _model_diagnostic[f"{_name}.grad"].accumulate(_output)
elif isinstance(_output, tuple): elif isinstance(_output, tuple):
@ -234,20 +281,19 @@ def attach_diagnostics(model: nn.Module,
for name, parameter in model.named_parameters(): for name, parameter in model.named_parameters():
def param_backward_hook(grad, def param_backward_hook(
_parameter=parameter, grad, _parameter=parameter, _model_diagnostic=ans, _name=name
_model_diagnostic=ans, ):
_name=name):
_model_diagnostic[f"{_name}.param_value"].accumulate(_parameter) _model_diagnostic[f"{_name}.param_value"].accumulate(_parameter)
_model_diagnostic[f"{_name}.param_grad"].accumulate(grad) _model_diagnostic[f"{_name}.param_grad"].accumulate(grad)
parameter.register_hook(param_backward_hook) parameter.register_hook(param_backward_hook)
return ans return ans
def _test_tensor_diagnostic(): def _test_tensor_diagnostic():
opts = TensorDiagnosticOptions(2**20, True) opts = TensorDiagnosticOptions(2 ** 20)
diagnostic = TensorDiagnostic(opts, "foo") diagnostic = TensorDiagnostic(opts, "foo")
@ -268,17 +314,5 @@ def _test_tensor_diagnostic():
diagnostic.print_diagnostics() diagnostic.print_diagnostics()
if __name__ == "__main__":
if __name__ == '__main__':
_test_tensor_diagnostic() _test_tensor_diagnostic()
def _test_func():
ans = []
for i in range(10):
x = list()
x.append(i)
def func():
return x
ans.append(func)
return ans