mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 16:14:17 +00:00
Update docs of arguments, and remove stats_types() function in TensorDiagnosticOptions object.
This commit is contained in:
parent
828a23daad
commit
87b4619f12
@ -521,7 +521,6 @@ def train_one_epoch(
|
|||||||
if params.print_diagnostics and batch_idx == 5:
|
if params.print_diagnostics and batch_idx == 5:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Epoch {params.cur_epoch}, "
|
f"Epoch {params.cur_epoch}, "
|
||||||
@ -631,10 +630,11 @@ def run(rank, world_size, args):
|
|||||||
librispeech = LibriSpeechAsrDataModule(args)
|
librispeech = LibriSpeechAsrDataModule(args)
|
||||||
|
|
||||||
if params.print_diagnostics:
|
if params.print_diagnostics:
|
||||||
opts = diagnostics.TensorDiagnosticOptions(2**22) # allow 4 megabytes per sub-module
|
opts = diagnostics.TensorDiagnosticOptions(
|
||||||
|
2 ** 22
|
||||||
|
) # allow 4 megabytes per sub-module
|
||||||
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||||
|
|
||||||
|
|
||||||
train_cuts = librispeech.train_clean_100_cuts()
|
train_cuts = librispeech.train_clean_100_cuts()
|
||||||
if params.full_libri:
|
if params.full_libri:
|
||||||
train_cuts += librispeech.train_clean_360_cuts()
|
train_cuts += librispeech.train_clean_360_cuts()
|
||||||
|
@ -1,94 +1,97 @@
|
|||||||
import torch
|
|
||||||
from torch import Tensor
|
|
||||||
from torch import nn
|
|
||||||
import math
|
|
||||||
import random
|
import random
|
||||||
from typing import Tuple, List
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
|
||||||
class TensorDiagnosticOptions(object):
|
class TensorDiagnosticOptions(object):
|
||||||
"""
|
"""
|
||||||
Options object for tensor diagnostics:
|
Options object for tensor diagnostics:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
memory_limit: the maximum number of bytes per tensor (limits how many copies
|
memory_limit: the maximum number of bytes per tensor (limits how many
|
||||||
of the tensor we cache).
|
copies of the tensor we cache).
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, memory_limit: int,
|
|
||||||
print_pos_ratio: bool = True):
|
def __init__(self, memory_limit: int):
|
||||||
self.memory_limit = memory_limit
|
self.memory_limit = memory_limit
|
||||||
self.print_pos_ratio = print_pos_ratio
|
|
||||||
|
|
||||||
def dim_is_summarized(self, size: int):
|
def dim_is_summarized(self, size: int):
|
||||||
return size > 10 and size != 31
|
return size > 10 and size != 31
|
||||||
|
|
||||||
def stats_types(self):
|
|
||||||
if self.print_pos_ratio:
|
|
||||||
return ["mean-abs", "pos-ratio"]
|
|
||||||
else:
|
|
||||||
return ["mean-abs"]
|
|
||||||
|
|
||||||
|
def get_sum_abs_stats(
|
||||||
|
x: Tensor, dim: int, stats_type: str
|
||||||
def get_sum_abs_stats(x: Tensor, dim: int,
|
) -> Tuple[Tensor, int]:
|
||||||
stats_type: str) -> Tuple[Tensor, int]:
|
|
||||||
"""
|
"""
|
||||||
Returns the sum-of-absolute-value of this Tensor, for each
|
Returns the sum-of-absolute-value of this Tensor, for each index into
|
||||||
index into the specified axis/dim of the tensor.
|
the specified axis/dim of the tensor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: Tensor, tensor to be analyzed
|
x: Tensor, tensor to be analyzed
|
||||||
dim: dimension with 0 <= dim < x.ndim
|
dim: dimension with 0 <= dim < x.ndim
|
||||||
stats_type: either "mean-abs" in which case the stats represent the
|
stats_type: either "mean-abs" in which case the stats represent the
|
||||||
mean absolute value, or "pos-ratio" in which case the
|
mean absolute value, or "pos-ratio" in which case the stats represent
|
||||||
stats represent the proportion of positive values (actually:
|
the proportion of positive values (actually: the tensor is count of
|
||||||
the tensor is count of positive values, count is the count of
|
positive values, count is the count of all values).
|
||||||
all values).
|
|
||||||
Returns (sum_abs, count)
|
Returns (sum_abs, count):
|
||||||
where sum_abs is a Tensor of shape (x.shape[dim],), and the count
|
where sum_abs is a Tensor of shape (x.shape[dim],), and the count
|
||||||
is an integer saying how many items were counted in each element
|
is an integer saying how many items were counted in each element
|
||||||
of sum_abs.
|
of sum_abs.
|
||||||
"""
|
"""
|
||||||
if stats_type == "mean-abs":
|
if stats_type == "mean-abs":
|
||||||
x = x.abs()
|
x = x.abs()
|
||||||
else:
|
else:
|
||||||
assert stats_type == "pos-ratio"
|
assert stats_type == "pos-ratio"
|
||||||
x = (x > 0).to(dtype=torch.float)
|
x = (x > 0).to(dtype=torch.float)
|
||||||
|
|
||||||
orig_numel = x.numel()
|
orig_numel = x.numel()
|
||||||
sum_dims = [ d for d in range(x.ndim) if d != dim ]
|
sum_dims = [d for d in range(x.ndim) if d != dim]
|
||||||
x = torch.sum(x, dim=sum_dims)
|
x = torch.sum(x, dim=sum_dims)
|
||||||
count = orig_numel // x.numel()
|
count = orig_numel // x.numel()
|
||||||
x = x.flatten()
|
x = x.flatten()
|
||||||
|
|
||||||
return x, count
|
return x, count
|
||||||
|
|
||||||
def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
|
|
||||||
options: TensorDiagnosticOptions,
|
def get_diagnostics_for_dim(
|
||||||
sizes_same: bool,
|
dim: int,
|
||||||
stats_type: str):
|
tensors: List[Tensor],
|
||||||
|
options: TensorDiagnosticOptions,
|
||||||
|
sizes_same: bool,
|
||||||
|
stats_type: str,
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
This function gets diagnostics for a dimension of a module.
|
This function gets diagnostics for a dimension of a module.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
|
dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
|
||||||
options: options object
|
tensors: list of cached tensors to get the stats
|
||||||
sizes_same: true if all the tensor sizes are the same on this dimension
|
options: options object
|
||||||
stats_type: either "mean-abs" or "pos-ratio", dictates the type of stats
|
sizes_same: true if all the tensor sizes are the same on this dimension
|
||||||
we accumulate, mean-abs is mean absolute value, "pos-ratio"
|
stats_type: either "mean-abs" or "pos-ratio", dictates the type of
|
||||||
is proportion of positive to nonnegative values.
|
stats we accumulate, mean-abs is mean absolute value, "pos-ratio" is
|
||||||
|
proportion of positive to nonnegative values.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Diagnostic as a string, either percentiles or the actual values,
|
Diagnostic as a string, either percentiles or the actual values,
|
||||||
see the code.
|
see the code.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# stats_and_counts is a list of pair (Tensor, int)
|
# stats_and_counts is a list of pair (Tensor, int)
|
||||||
stats_and_counts = [ get_sum_abs_stats(x, dim, stats_type) for x in tensors ]
|
stats_and_counts = [get_sum_abs_stats(x, dim, stats_type) for x in tensors]
|
||||||
stats = [ x[0] for x in stats_and_counts ]
|
stats = [x[0] for x in stats_and_counts]
|
||||||
counts = [ x[1] for x in stats_and_counts ]
|
counts = [x[1] for x in stats_and_counts]
|
||||||
if sizes_same:
|
if sizes_same:
|
||||||
stats = torch.stack(stats).sum(dim=0)
|
stats = torch.stack(stats).sum(dim=0)
|
||||||
count = sum(counts)
|
count = sum(counts)
|
||||||
stats = stats / count
|
stats = stats / count
|
||||||
else:
|
else:
|
||||||
stats = [ x[0] / x[1] for x in stats_and_counts ]
|
stats = [x[0] / x[1] for x in stats_and_counts]
|
||||||
stats = torch.cat(stats, dim=0)
|
stats = torch.cat(stats, dim=0)
|
||||||
|
|
||||||
# if `summarize` we print percentiles of the stats; else,
|
# if `summarize` we print percentiles of the stats; else,
|
||||||
# we print out individual elements.
|
# we print out individual elements.
|
||||||
summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
|
summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
|
||||||
@ -101,89 +104,117 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
|
|||||||
for i in range(num_percentiles + 1):
|
for i in range(num_percentiles + 1):
|
||||||
index = (i * (size - 1)) // num_percentiles
|
index = (i * (size - 1)) // num_percentiles
|
||||||
percentiles.append(stats[index].item())
|
percentiles.append(stats[index].item())
|
||||||
percentiles = [ '%.2g' % x for x in percentiles ]
|
percentiles = ["%.2g" % x for x in percentiles]
|
||||||
percentiles = ' '.join(percentiles)
|
percentiles = " ".join(percentiles)
|
||||||
return f'percentiles: [{percentiles}]'
|
return f"percentiles: [{percentiles}]"
|
||||||
else:
|
else:
|
||||||
stats = stats.tolist()
|
stats = stats.tolist()
|
||||||
stats = [ '%.2g' % x for x in stats ]
|
stats = ["%.2g" % x for x in stats]
|
||||||
stats = '[' + ' '.join(stats) + ']'
|
stats = "[" + " ".join(stats) + "]"
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
|
||||||
|
def print_diagnostics_for_dim(
|
||||||
|
name: str, dim: int, tensors: List[Tensor], options: TensorDiagnosticOptions
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
This function prints diagnostics for a dimension of a tensor.
|
||||||
|
|
||||||
def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor],
|
Args:
|
||||||
options: TensorDiagnosticOptions):
|
name: the tensor name
|
||||||
|
dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
|
||||||
|
tensors: list of cached tensors to get the stats
|
||||||
|
options: options object
|
||||||
|
"""
|
||||||
|
|
||||||
for stats_type in options.stats_types():
|
for stats_type in ["mean-abs", "pos-ratio"]:
|
||||||
# stats_type will be "mean-abs" or "pos-ratio".
|
# stats_type will be "mean-abs" or "pos-ratio".
|
||||||
sizes = [ x.shape[dim] for x in tensors ]
|
sizes = [x.shape[dim] for x in tensors]
|
||||||
sizes_same = all([ x == sizes[0] for x in sizes ])
|
sizes_same = all([x == sizes[0] for x in sizes])
|
||||||
s = get_diagnostics_for_dim(dim, tensors,
|
s = get_diagnostics_for_dim(
|
||||||
options, sizes_same,
|
dim, tensors, options, sizes_same, stats_type
|
||||||
stats_type)
|
)
|
||||||
|
|
||||||
min_size = min(sizes)
|
min_size = min(sizes)
|
||||||
max_size = max(sizes)
|
max_size = max(sizes)
|
||||||
size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}"
|
size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}"
|
||||||
# stats_type will be "mean-abs" or "pos-ratio".
|
|
||||||
print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}")
|
print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}")
|
||||||
|
|
||||||
|
|
||||||
class TensorDiagnostic(object):
|
class TensorDiagnostic(object):
|
||||||
"""
|
"""
|
||||||
This class is not directly used by the user, it is responsible for collecting
|
This class is not directly used by the user, it is responsible for
|
||||||
diagnostics for a single parameter tensor of a torch.Module.
|
collecting diagnostics for a single parameter tensor of a torch.nn.Module.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
opts: options object.
|
||||||
|
name: tensor name.
|
||||||
|
saved_tensors: list of cached tensors.
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
|
||||||
opts: TensorDiagnosticOptions,
|
def __init__(self, opts: TensorDiagnosticOptions, name: str):
|
||||||
name: str):
|
|
||||||
self.name = name
|
self.name = name
|
||||||
self.opts = opts
|
self.opts = opts
|
||||||
self.saved_tensors = []
|
self.saved_tensors = []
|
||||||
|
|
||||||
def accumulate(self, x):
|
def accumulate(self, x):
|
||||||
|
"""Accumulate tensors."""
|
||||||
if isinstance(x, Tuple):
|
if isinstance(x, Tuple):
|
||||||
x = x[0]
|
x = x[0]
|
||||||
if not isinstance(x, Tensor):
|
if not isinstance(x, Tensor):
|
||||||
return
|
return
|
||||||
if x.device == torch.device('cpu'):
|
if x.device == torch.device("cpu"):
|
||||||
x = x.detach().clone()
|
x = x.detach().clone()
|
||||||
else:
|
else:
|
||||||
x = x.detach().to('cpu', non_blocking=True)
|
x = x.detach().to("cpu", non_blocking=True)
|
||||||
self.saved_tensors.append(x)
|
self.saved_tensors.append(x)
|
||||||
l = len(self.saved_tensors)
|
num = len(self.saved_tensors)
|
||||||
if l & (l - 1) == 0: # power of 2..
|
if num & (num - 1) == 0: # power of 2..
|
||||||
self._limit_memory()
|
self._limit_memory()
|
||||||
|
|
||||||
def _limit_memory(self):
|
def _limit_memory(self):
|
||||||
|
"""Only keep the newly cached tensors to limit memory."""
|
||||||
if len(self.saved_tensors) > 1024:
|
if len(self.saved_tensors) > 1024:
|
||||||
self.saved_tensors = self.saved_tensors[-1024:]
|
self.saved_tensors = self.saved_tensors[-1024:]
|
||||||
return
|
return
|
||||||
|
|
||||||
tot_mem = 0.0
|
tot_mem = 0.0
|
||||||
for i in reversed(range(len(self.saved_tensors))):
|
for i in reversed(range(len(self.saved_tensors))):
|
||||||
tot_mem += self.saved_tensors[i].numel() * self.saved_tensors[i].element_size()
|
tot_mem += (
|
||||||
|
self.saved_tensors[i].numel()
|
||||||
|
* self.saved_tensors[i].element_size()
|
||||||
|
)
|
||||||
if tot_mem > self.opts.memory_limit:
|
if tot_mem > self.opts.memory_limit:
|
||||||
self.saved_tensors = self.saved_tensors[i:]
|
self.saved_tensors = self.saved_tensors[i:]
|
||||||
return
|
return
|
||||||
|
|
||||||
def print_diagnostics(self):
|
def print_diagnostics(self):
|
||||||
|
"""Print diagnostics for each dimension of the tensor."""
|
||||||
if len(self.saved_tensors) == 0:
|
if len(self.saved_tensors) == 0:
|
||||||
print("{name}: no stats".format(name=self.name))
|
print("{name}: no stats".format(name=self.name))
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.saved_tensors[0].ndim == 0:
|
if self.saved_tensors[0].ndim == 0:
|
||||||
# ensure there is at least one dim.
|
# ensure there is at least one dim.
|
||||||
self.saved_tensors = [ x.unsqueeze(0) for x in self.saved_tensors ]
|
self.saved_tensors = [x.unsqueeze(0) for x in self.saved_tensors]
|
||||||
|
|
||||||
ndim = self.saved_tensors[0].ndim
|
ndim = self.saved_tensors[0].ndim
|
||||||
for dim in range(ndim):
|
for dim in range(ndim):
|
||||||
print_diagnostics_for_dim(self.name, dim,
|
print_diagnostics_for_dim(
|
||||||
self.saved_tensors,
|
self.name, dim, self.saved_tensors, self.opts
|
||||||
self.opts)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ModelDiagnostic(object):
|
class ModelDiagnostic(object):
|
||||||
|
"""
|
||||||
|
This class stores diagnostics for all tensors in the torch.nn.Module.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
diagnostics: a dictionary, whose keys are the tensors names and
|
||||||
|
the values are corresponding TensorDiagnostic objects.
|
||||||
|
opts: options object.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, opts: TensorDiagnosticOptions):
|
def __init__(self, opts: TensorDiagnosticOptions):
|
||||||
self.diagnostics = dict()
|
self.diagnostics = dict()
|
||||||
self.opts = opts
|
self.opts = opts
|
||||||
@ -194,35 +225,51 @@ class ModelDiagnostic(object):
|
|||||||
return self.diagnostics[name]
|
return self.diagnostics[name]
|
||||||
|
|
||||||
def print_diagnostics(self):
|
def print_diagnostics(self):
|
||||||
|
"""Print diagnostics for each tensor."""
|
||||||
for k in sorted(self.diagnostics.keys()):
|
for k in sorted(self.diagnostics.keys()):
|
||||||
self.diagnostics[k].print_diagnostics()
|
self.diagnostics[k].print_diagnostics()
|
||||||
|
|
||||||
|
|
||||||
|
def attach_diagnostics(
|
||||||
|
model: nn.Module, opts: TensorDiagnosticOptions
|
||||||
|
) -> ModelDiagnostic:
|
||||||
|
"""
|
||||||
|
Attach a ModelDiagnostic object to the model by
|
||||||
|
1) registering forward hook and backward hook on each module, to accumulate
|
||||||
|
its output tensors and gradient tensors, respectively;
|
||||||
|
2) registering backward hook on each module parameter, to accumulate its
|
||||||
|
values and gradients.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: the model to be analyzed.
|
||||||
|
opts: options object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The ModelDiagnostic object attached to the model.
|
||||||
|
"""
|
||||||
|
|
||||||
def attach_diagnostics(model: nn.Module,
|
|
||||||
opts: TensorDiagnosticOptions) -> ModelDiagnostic:
|
|
||||||
ans = ModelDiagnostic(opts)
|
ans = ModelDiagnostic(opts)
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if name == '':
|
if name == "":
|
||||||
name = "<top-level>"
|
name = "<top-level>"
|
||||||
forward_diagnostic = TensorDiagnostic(opts, name + ".output")
|
|
||||||
backward_diagnostic = TensorDiagnostic(opts, name + ".grad")
|
|
||||||
|
|
||||||
|
# setting model_diagnostic=ans and n=name below, instead of trying to
|
||||||
# setting model_diagnostic=ans and n=name below, instead of trying to capture the variables,
|
# capture the variables, ensures that we use the current values.
|
||||||
# ensures that we use the current values. (matters for name, since
|
# (matters for name, since the variable gets overwritten).
|
||||||
# the variable gets overwritten). these closures don't really capture
|
# these closures don't really capture by value, only by
|
||||||
# by value, only by "the final value the variable got in the function" :-(
|
# "the final value the variable got in the function" :-(
|
||||||
def forward_hook(_module, _input, _output,
|
def forward_hook(
|
||||||
_model_diagnostic=ans, _name=name):
|
_module, _input, _output, _model_diagnostic=ans, _name=name
|
||||||
|
):
|
||||||
if isinstance(_output, Tensor):
|
if isinstance(_output, Tensor):
|
||||||
_model_diagnostic[f"{_name}.output"].accumulate(_output)
|
_model_diagnostic[f"{_name}.output"].accumulate(_output)
|
||||||
elif isinstance(_output, tuple):
|
elif isinstance(_output, tuple):
|
||||||
for i, o in enumerate(_output):
|
for i, o in enumerate(_output):
|
||||||
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(o)
|
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(o)
|
||||||
|
|
||||||
def backward_hook(_module, _input, _output,
|
def backward_hook(
|
||||||
_model_diagnostic=ans, _name=name):
|
_module, _input, _output, _model_diagnostic=ans, _name=name
|
||||||
|
):
|
||||||
if isinstance(_output, Tensor):
|
if isinstance(_output, Tensor):
|
||||||
_model_diagnostic[f"{_name}.grad"].accumulate(_output)
|
_model_diagnostic[f"{_name}.grad"].accumulate(_output)
|
||||||
elif isinstance(_output, tuple):
|
elif isinstance(_output, tuple):
|
||||||
@ -234,20 +281,19 @@ def attach_diagnostics(model: nn.Module,
|
|||||||
|
|
||||||
for name, parameter in model.named_parameters():
|
for name, parameter in model.named_parameters():
|
||||||
|
|
||||||
def param_backward_hook(grad,
|
def param_backward_hook(
|
||||||
_parameter=parameter,
|
grad, _parameter=parameter, _model_diagnostic=ans, _name=name
|
||||||
_model_diagnostic=ans,
|
):
|
||||||
_name=name):
|
|
||||||
_model_diagnostic[f"{_name}.param_value"].accumulate(_parameter)
|
_model_diagnostic[f"{_name}.param_value"].accumulate(_parameter)
|
||||||
_model_diagnostic[f"{_name}.param_grad"].accumulate(grad)
|
_model_diagnostic[f"{_name}.param_grad"].accumulate(grad)
|
||||||
|
|
||||||
parameter.register_hook(param_backward_hook)
|
parameter.register_hook(param_backward_hook)
|
||||||
|
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _test_tensor_diagnostic():
|
def _test_tensor_diagnostic():
|
||||||
opts = TensorDiagnosticOptions(2**20, True)
|
opts = TensorDiagnosticOptions(2 ** 20)
|
||||||
|
|
||||||
diagnostic = TensorDiagnostic(opts, "foo")
|
diagnostic = TensorDiagnostic(opts, "foo")
|
||||||
|
|
||||||
@ -268,17 +314,5 @@ def _test_tensor_diagnostic():
|
|||||||
diagnostic.print_diagnostics()
|
diagnostic.print_diagnostics()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
if __name__ == '__main__':
|
|
||||||
_test_tensor_diagnostic()
|
_test_tensor_diagnostic()
|
||||||
|
|
||||||
|
|
||||||
def _test_func():
|
|
||||||
ans = []
|
|
||||||
for i in range(10):
|
|
||||||
x = list()
|
|
||||||
x.append(i)
|
|
||||||
def func():
|
|
||||||
return x
|
|
||||||
ans.append(func)
|
|
||||||
return ans
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user