mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Update docs of arguments, and remove stats_types() function in TensorDiagnosticOptions object.
This commit is contained in:
parent
828a23daad
commit
87b4619f12
@ -521,7 +521,6 @@ def train_one_epoch(
|
||||
if params.print_diagnostics and batch_idx == 5:
|
||||
return
|
||||
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
logging.info(
|
||||
f"Epoch {params.cur_epoch}, "
|
||||
@ -631,10 +630,11 @@ def run(rank, world_size, args):
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
if params.print_diagnostics:
|
||||
opts = diagnostics.TensorDiagnosticOptions(2**22) # allow 4 megabytes per sub-module
|
||||
opts = diagnostics.TensorDiagnosticOptions(
|
||||
2 ** 22
|
||||
) # allow 4 megabytes per sub-module
|
||||
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||
|
||||
|
||||
train_cuts = librispeech.train_clean_100_cuts()
|
||||
if params.full_libri:
|
||||
train_cuts += librispeech.train_clean_360_cuts()
|
||||
|
@ -1,94 +1,97 @@
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
import math
|
||||
import random
|
||||
from typing import Tuple, List
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class TensorDiagnosticOptions(object):
|
||||
"""
|
||||
Options object for tensor diagnostics:
|
||||
|
||||
Args:
|
||||
memory_limit: the maximum number of bytes per tensor (limits how many copies
|
||||
of the tensor we cache).
|
||||
|
||||
Args:
|
||||
memory_limit: the maximum number of bytes per tensor (limits how many
|
||||
copies of the tensor we cache).
|
||||
"""
|
||||
def __init__(self, memory_limit: int,
|
||||
print_pos_ratio: bool = True):
|
||||
|
||||
def __init__(self, memory_limit: int):
|
||||
self.memory_limit = memory_limit
|
||||
self.print_pos_ratio = print_pos_ratio
|
||||
|
||||
def dim_is_summarized(self, size: int):
|
||||
return size > 10 and size != 31
|
||||
|
||||
def stats_types(self):
|
||||
if self.print_pos_ratio:
|
||||
return ["mean-abs", "pos-ratio"]
|
||||
else:
|
||||
return ["mean-abs"]
|
||||
|
||||
|
||||
|
||||
def get_sum_abs_stats(x: Tensor, dim: int,
|
||||
stats_type: str) -> Tuple[Tensor, int]:
|
||||
def get_sum_abs_stats(
|
||||
x: Tensor, dim: int, stats_type: str
|
||||
) -> Tuple[Tensor, int]:
|
||||
"""
|
||||
Returns the sum-of-absolute-value of this Tensor, for each
|
||||
index into the specified axis/dim of the tensor.
|
||||
Returns the sum-of-absolute-value of this Tensor, for each index into
|
||||
the specified axis/dim of the tensor.
|
||||
|
||||
Args:
|
||||
x: Tensor, tensor to be analyzed
|
||||
dim: dimension with 0 <= dim < x.ndim
|
||||
stats_type: either "mean-abs" in which case the stats represent the
|
||||
mean absolute value, or "pos-ratio" in which case the
|
||||
stats represent the proportion of positive values (actually:
|
||||
the tensor is count of positive values, count is the count of
|
||||
all values).
|
||||
Returns (sum_abs, count)
|
||||
where sum_abs is a Tensor of shape (x.shape[dim],), and the count
|
||||
is an integer saying how many items were counted in each element
|
||||
of sum_abs.
|
||||
x: Tensor, tensor to be analyzed
|
||||
dim: dimension with 0 <= dim < x.ndim
|
||||
stats_type: either "mean-abs" in which case the stats represent the
|
||||
mean absolute value, or "pos-ratio" in which case the stats represent
|
||||
the proportion of positive values (actually: the tensor is count of
|
||||
positive values, count is the count of all values).
|
||||
|
||||
Returns (sum_abs, count):
|
||||
where sum_abs is a Tensor of shape (x.shape[dim],), and the count
|
||||
is an integer saying how many items were counted in each element
|
||||
of sum_abs.
|
||||
"""
|
||||
if stats_type == "mean-abs":
|
||||
x = x.abs()
|
||||
else:
|
||||
assert stats_type == "pos-ratio"
|
||||
x = (x > 0).to(dtype=torch.float)
|
||||
|
||||
orig_numel = x.numel()
|
||||
sum_dims = [ d for d in range(x.ndim) if d != dim ]
|
||||
sum_dims = [d for d in range(x.ndim) if d != dim]
|
||||
x = torch.sum(x, dim=sum_dims)
|
||||
count = orig_numel // x.numel()
|
||||
x = x.flatten()
|
||||
|
||||
return x, count
|
||||
|
||||
def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
|
||||
options: TensorDiagnosticOptions,
|
||||
sizes_same: bool,
|
||||
stats_type: str):
|
||||
|
||||
def get_diagnostics_for_dim(
|
||||
dim: int,
|
||||
tensors: List[Tensor],
|
||||
options: TensorDiagnosticOptions,
|
||||
sizes_same: bool,
|
||||
stats_type: str,
|
||||
) -> str:
|
||||
"""
|
||||
This function gets diagnostics for a dimension of a module.
|
||||
|
||||
Args:
|
||||
dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
|
||||
options: options object
|
||||
sizes_same: true if all the tensor sizes are the same on this dimension
|
||||
stats_type: either "mean-abs" or "pos-ratio", dictates the type of stats
|
||||
we accumulate, mean-abs is mean absolute value, "pos-ratio"
|
||||
is proportion of positive to nonnegative values.
|
||||
dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
|
||||
tensors: list of cached tensors to get the stats
|
||||
options: options object
|
||||
sizes_same: true if all the tensor sizes are the same on this dimension
|
||||
stats_type: either "mean-abs" or "pos-ratio", dictates the type of
|
||||
stats we accumulate, mean-abs is mean absolute value, "pos-ratio" is
|
||||
proportion of positive to nonnegative values.
|
||||
|
||||
Returns:
|
||||
Diagnostic as a string, either percentiles or the actual values,
|
||||
see the code.
|
||||
Diagnostic as a string, either percentiles or the actual values,
|
||||
see the code.
|
||||
"""
|
||||
|
||||
# stats_and_counts is a list of pair (Tensor, int)
|
||||
stats_and_counts = [ get_sum_abs_stats(x, dim, stats_type) for x in tensors ]
|
||||
stats = [ x[0] for x in stats_and_counts ]
|
||||
counts = [ x[1] for x in stats_and_counts ]
|
||||
stats_and_counts = [get_sum_abs_stats(x, dim, stats_type) for x in tensors]
|
||||
stats = [x[0] for x in stats_and_counts]
|
||||
counts = [x[1] for x in stats_and_counts]
|
||||
if sizes_same:
|
||||
stats = torch.stack(stats).sum(dim=0)
|
||||
count = sum(counts)
|
||||
stats = stats / count
|
||||
else:
|
||||
stats = [ x[0] / x[1] for x in stats_and_counts ]
|
||||
stats = [x[0] / x[1] for x in stats_and_counts]
|
||||
stats = torch.cat(stats, dim=0)
|
||||
|
||||
# if `summarize` we print percentiles of the stats; else,
|
||||
# we print out individual elements.
|
||||
summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
|
||||
@ -101,89 +104,117 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
|
||||
for i in range(num_percentiles + 1):
|
||||
index = (i * (size - 1)) // num_percentiles
|
||||
percentiles.append(stats[index].item())
|
||||
percentiles = [ '%.2g' % x for x in percentiles ]
|
||||
percentiles = ' '.join(percentiles)
|
||||
return f'percentiles: [{percentiles}]'
|
||||
percentiles = ["%.2g" % x for x in percentiles]
|
||||
percentiles = " ".join(percentiles)
|
||||
return f"percentiles: [{percentiles}]"
|
||||
else:
|
||||
stats = stats.tolist()
|
||||
stats = [ '%.2g' % x for x in stats ]
|
||||
stats = '[' + ' '.join(stats) + ']'
|
||||
stats = ["%.2g" % x for x in stats]
|
||||
stats = "[" + " ".join(stats) + "]"
|
||||
return stats
|
||||
|
||||
|
||||
def print_diagnostics_for_dim(
|
||||
name: str, dim: int, tensors: List[Tensor], options: TensorDiagnosticOptions
|
||||
):
|
||||
"""
|
||||
This function prints diagnostics for a dimension of a tensor.
|
||||
|
||||
def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor],
|
||||
options: TensorDiagnosticOptions):
|
||||
Args:
|
||||
name: the tensor name
|
||||
dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
|
||||
tensors: list of cached tensors to get the stats
|
||||
options: options object
|
||||
"""
|
||||
|
||||
for stats_type in options.stats_types():
|
||||
for stats_type in ["mean-abs", "pos-ratio"]:
|
||||
# stats_type will be "mean-abs" or "pos-ratio".
|
||||
sizes = [ x.shape[dim] for x in tensors ]
|
||||
sizes_same = all([ x == sizes[0] for x in sizes ])
|
||||
s = get_diagnostics_for_dim(dim, tensors,
|
||||
options, sizes_same,
|
||||
stats_type)
|
||||
sizes = [x.shape[dim] for x in tensors]
|
||||
sizes_same = all([x == sizes[0] for x in sizes])
|
||||
s = get_diagnostics_for_dim(
|
||||
dim, tensors, options, sizes_same, stats_type
|
||||
)
|
||||
|
||||
min_size = min(sizes)
|
||||
max_size = max(sizes)
|
||||
size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}"
|
||||
# stats_type will be "mean-abs" or "pos-ratio".
|
||||
print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}")
|
||||
|
||||
|
||||
class TensorDiagnostic(object):
|
||||
"""
|
||||
This class is not directly used by the user, it is responsible for collecting
|
||||
diagnostics for a single parameter tensor of a torch.Module.
|
||||
This class is not directly used by the user, it is responsible for
|
||||
collecting diagnostics for a single parameter tensor of a torch.nn.Module.
|
||||
|
||||
Attributes:
|
||||
opts: options object.
|
||||
name: tensor name.
|
||||
saved_tensors: list of cached tensors.
|
||||
"""
|
||||
def __init__(self,
|
||||
opts: TensorDiagnosticOptions,
|
||||
name: str):
|
||||
|
||||
def __init__(self, opts: TensorDiagnosticOptions, name: str):
|
||||
self.name = name
|
||||
self.opts = opts
|
||||
self.saved_tensors = []
|
||||
|
||||
def accumulate(self, x):
|
||||
"""Accumulate tensors."""
|
||||
if isinstance(x, Tuple):
|
||||
x = x[0]
|
||||
if not isinstance(x, Tensor):
|
||||
return
|
||||
if x.device == torch.device('cpu'):
|
||||
if x.device == torch.device("cpu"):
|
||||
x = x.detach().clone()
|
||||
else:
|
||||
x = x.detach().to('cpu', non_blocking=True)
|
||||
x = x.detach().to("cpu", non_blocking=True)
|
||||
self.saved_tensors.append(x)
|
||||
l = len(self.saved_tensors)
|
||||
if l & (l - 1) == 0: # power of 2..
|
||||
num = len(self.saved_tensors)
|
||||
if num & (num - 1) == 0: # power of 2..
|
||||
self._limit_memory()
|
||||
|
||||
def _limit_memory(self):
|
||||
"""Only keep the newly cached tensors to limit memory."""
|
||||
if len(self.saved_tensors) > 1024:
|
||||
self.saved_tensors = self.saved_tensors[-1024:]
|
||||
return
|
||||
|
||||
tot_mem = 0.0
|
||||
for i in reversed(range(len(self.saved_tensors))):
|
||||
tot_mem += self.saved_tensors[i].numel() * self.saved_tensors[i].element_size()
|
||||
tot_mem += (
|
||||
self.saved_tensors[i].numel()
|
||||
* self.saved_tensors[i].element_size()
|
||||
)
|
||||
if tot_mem > self.opts.memory_limit:
|
||||
self.saved_tensors = self.saved_tensors[i:]
|
||||
return
|
||||
|
||||
def print_diagnostics(self):
|
||||
"""Print diagnostics for each dimension of the tensor."""
|
||||
if len(self.saved_tensors) == 0:
|
||||
print("{name}: no stats".format(name=self.name))
|
||||
return
|
||||
|
||||
if self.saved_tensors[0].ndim == 0:
|
||||
# ensure there is at least one dim.
|
||||
self.saved_tensors = [ x.unsqueeze(0) for x in self.saved_tensors ]
|
||||
self.saved_tensors = [x.unsqueeze(0) for x in self.saved_tensors]
|
||||
|
||||
ndim = self.saved_tensors[0].ndim
|
||||
for dim in range(ndim):
|
||||
print_diagnostics_for_dim(self.name, dim,
|
||||
self.saved_tensors,
|
||||
self.opts)
|
||||
print_diagnostics_for_dim(
|
||||
self.name, dim, self.saved_tensors, self.opts
|
||||
)
|
||||
|
||||
|
||||
class ModelDiagnostic(object):
|
||||
"""
|
||||
This class stores diagnostics for all tensors in the torch.nn.Module.
|
||||
|
||||
Attributes:
|
||||
diagnostics: a dictionary, whose keys are the tensors names and
|
||||
the values are corresponding TensorDiagnostic objects.
|
||||
opts: options object.
|
||||
"""
|
||||
|
||||
def __init__(self, opts: TensorDiagnosticOptions):
|
||||
self.diagnostics = dict()
|
||||
self.opts = opts
|
||||
@ -194,35 +225,51 @@ class ModelDiagnostic(object):
|
||||
return self.diagnostics[name]
|
||||
|
||||
def print_diagnostics(self):
|
||||
"""Print diagnostics for each tensor."""
|
||||
for k in sorted(self.diagnostics.keys()):
|
||||
self.diagnostics[k].print_diagnostics()
|
||||
|
||||
|
||||
def attach_diagnostics(
|
||||
model: nn.Module, opts: TensorDiagnosticOptions
|
||||
) -> ModelDiagnostic:
|
||||
"""
|
||||
Attach a ModelDiagnostic object to the model by
|
||||
1) registering forward hook and backward hook on each module, to accumulate
|
||||
its output tensors and gradient tensors, respectively;
|
||||
2) registering backward hook on each module parameter, to accumulate its
|
||||
values and gradients.
|
||||
|
||||
Args:
|
||||
model: the model to be analyzed.
|
||||
opts: options object.
|
||||
|
||||
Returns:
|
||||
The ModelDiagnostic object attached to the model.
|
||||
"""
|
||||
|
||||
def attach_diagnostics(model: nn.Module,
|
||||
opts: TensorDiagnosticOptions) -> ModelDiagnostic:
|
||||
ans = ModelDiagnostic(opts)
|
||||
for name, module in model.named_modules():
|
||||
if name == '':
|
||||
if name == "":
|
||||
name = "<top-level>"
|
||||
forward_diagnostic = TensorDiagnostic(opts, name + ".output")
|
||||
backward_diagnostic = TensorDiagnostic(opts, name + ".grad")
|
||||
|
||||
|
||||
# setting model_diagnostic=ans and n=name below, instead of trying to capture the variables,
|
||||
# ensures that we use the current values. (matters for name, since
|
||||
# the variable gets overwritten). these closures don't really capture
|
||||
# by value, only by "the final value the variable got in the function" :-(
|
||||
def forward_hook(_module, _input, _output,
|
||||
_model_diagnostic=ans, _name=name):
|
||||
# setting model_diagnostic=ans and n=name below, instead of trying to
|
||||
# capture the variables, ensures that we use the current values.
|
||||
# (matters for name, since the variable gets overwritten).
|
||||
# these closures don't really capture by value, only by
|
||||
# "the final value the variable got in the function" :-(
|
||||
def forward_hook(
|
||||
_module, _input, _output, _model_diagnostic=ans, _name=name
|
||||
):
|
||||
if isinstance(_output, Tensor):
|
||||
_model_diagnostic[f"{_name}.output"].accumulate(_output)
|
||||
elif isinstance(_output, tuple):
|
||||
for i, o in enumerate(_output):
|
||||
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(o)
|
||||
|
||||
def backward_hook(_module, _input, _output,
|
||||
_model_diagnostic=ans, _name=name):
|
||||
def backward_hook(
|
||||
_module, _input, _output, _model_diagnostic=ans, _name=name
|
||||
):
|
||||
if isinstance(_output, Tensor):
|
||||
_model_diagnostic[f"{_name}.grad"].accumulate(_output)
|
||||
elif isinstance(_output, tuple):
|
||||
@ -234,20 +281,19 @@ def attach_diagnostics(model: nn.Module,
|
||||
|
||||
for name, parameter in model.named_parameters():
|
||||
|
||||
def param_backward_hook(grad,
|
||||
_parameter=parameter,
|
||||
_model_diagnostic=ans,
|
||||
_name=name):
|
||||
def param_backward_hook(
|
||||
grad, _parameter=parameter, _model_diagnostic=ans, _name=name
|
||||
):
|
||||
_model_diagnostic[f"{_name}.param_value"].accumulate(_parameter)
|
||||
_model_diagnostic[f"{_name}.param_grad"].accumulate(grad)
|
||||
|
||||
parameter.register_hook(param_backward_hook)
|
||||
|
||||
return ans
|
||||
|
||||
|
||||
|
||||
def _test_tensor_diagnostic():
|
||||
opts = TensorDiagnosticOptions(2**20, True)
|
||||
opts = TensorDiagnosticOptions(2 ** 20)
|
||||
|
||||
diagnostic = TensorDiagnostic(opts, "foo")
|
||||
|
||||
@ -268,17 +314,5 @@ def _test_tensor_diagnostic():
|
||||
diagnostic.print_diagnostics()
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
_test_tensor_diagnostic()
|
||||
|
||||
|
||||
def _test_func():
|
||||
ans = []
|
||||
for i in range(10):
|
||||
x = list()
|
||||
x.append(i)
|
||||
def func():
|
||||
return x
|
||||
ans.append(func)
|
||||
return ans
|
||||
|
Loading…
x
Reference in New Issue
Block a user