Add diagnostics (#230)

* Add diagnostics code.

* Move diagnostics code from local dir to the shared icefall dir

* Remove the diagnostics code in the local dir

* Update argument docs, and remove the stats_types() function from the TensorDiagnosticOptions object.

* Update argument docs.

* Add copyright information.

* Correct the year in the copyright information.

Co-authored-by: Daniel Povey <dpovey@gmail.com>
yaozengwei 2022-03-04 15:38:23 +08:00, committed by GitHub
parent 2f0fbf430c
commit ad62981765
2 changed files with 374 additions and 10 deletions


@@ -56,6 +56,7 @@ from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam

+from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
@@ -156,6 +157,13 @@ def get_parser():
         help="The seed for random generators intended for reproducibility",
     )

+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
     return parser
@@ -510,6 +518,8 @@ def train_one_epoch(
         loss.backward()
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()

+        if params.print_diagnostics and batch_idx == 5:
+            return
         if batch_idx % params.log_interval == 0:
             logging.info(
@@ -517,9 +527,6 @@ def train_one_epoch(
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], batch size: {batch_size}"
             )
-
-        if batch_idx % params.log_interval == 0:
-
             if tb_writer is not None:
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
@@ -622,6 +629,12 @@ def run(rank, world_size, args):
     librispeech = LibriSpeechAsrDataModule(args)

+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2 ** 22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
     train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
         train_cuts += librispeech.train_clean_360_cuts()
@@ -649,13 +662,14 @@ def run(rank, world_size, args):
     valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)

-    scan_pessimistic_batches_for_oom(
-        model=model,
-        train_dl=train_dl,
-        optimizer=optimizer,
-        sp=sp,
-        params=params,
-    )
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )

     for epoch in range(params.start_epoch, params.num_epochs):
         fix_random_seed(params.seed + epoch)
@@ -684,6 +698,10 @@ def run(rank, world_size, args):
             world_size=world_size,
         )

+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
         save_checkpoint(
             params=params,
             model=model,
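
Taken together, the training-side changes gate everything behind the new flag: when --print-diagnostics is given, the OOM scan is skipped, training stops after batch 5, and the accumulated stats are printed. A minimal sketch of that flow (the model and the surrounding training loop are placeholders, not code from this diff):

    opts = diagnostics.TensorDiagnosticOptions(2 ** 22)  # 4 MiB per sub-module
    diagnostic = diagnostics.attach_diagnostics(model, opts)
    # ... run a few training batches so the hooks see activations,
    # gradients, and parameter values ...
    diagnostic.print_diagnostics()  # one report per module and dimension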

icefall/diagnostics.py (new file, 346 lines)

@@ -0,0 +1,346 @@
# Copyright    2022  Xiaomi Corp.        (authors: Daniel Povey
#                                                  Zengwei Yao)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import random
from typing import List, Tuple

import torch
from torch import Tensor, nn


class TensorDiagnosticOptions(object):
    """Options object for tensor diagnostics:

    Args:
      memory_limit:
        The maximum number of bytes per tensor (limits how many copies
        of the tensor we cache).
    """

    def __init__(self, memory_limit: int):
        self.memory_limit = memory_limit

    def dim_is_summarized(self, size: int):
        return size > 10 and size != 31
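
# A small illustration of dim_is_summarized (the values follow directly
# from the rule above): dims larger than 10 are summarized as percentiles
# rather than printed element-wise, and size 31 is special-cased so it is
# never summarized.
#
#     opts = TensorDiagnosticOptions(2 ** 22)  # 4 MiB per cached tensor
#     opts.dim_is_summarized(512)  # True: report percentiles
#     opts.dim_is_summarized(4)    # False: print every element
#     opts.dim_is_summarized(31)   # False: special-cased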
def get_sum_abs_stats(
    x: Tensor, dim: int, stats_type: str
) -> Tuple[Tensor, int]:
    """Returns the sum-of-absolute-value of this Tensor, for each index into
    the specified axis/dim of the tensor.

    Args:
      x:
        Tensor to be analyzed.
      dim:
        Dimension with 0 <= dim < x.ndim.
      stats_type:
        Either "mean-abs", in which case the stats represent the mean
        absolute value, or "pos-ratio", in which case the stats represent
        the proportion of positive values (actually: the tensor holds the
        count of positive values, and `count` is the count of all values).
    Returns:
      (sum_abs, count) where sum_abs is a Tensor of shape (x.shape[dim],),
      and count is an integer saying how many items were counted in each
      element of sum_abs.
    """
    if stats_type == "mean-abs":
        x = x.abs()
    else:
        assert stats_type == "pos-ratio"
        x = (x > 0).to(dtype=torch.float)

    orig_numel = x.numel()
    sum_dims = [d for d in range(x.ndim) if d != dim]
    x = torch.sum(x, dim=sum_dims)
    count = orig_numel // x.numel()
    x = x.flatten()
    return x, count
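
# A worked example of get_sum_abs_stats on a 2x3 input:
#
#     x = torch.tensor([[-3.0, -2.0, -1.0], [0.0, 1.0, 2.0]])
#     s, count = get_sum_abs_stats(x, dim=1, stats_type="mean-abs")
#     # s == tensor([3., 3., 3.]): |x| summed over every dim except dim 1;
#     # count == 2: two elements were pooled into each entry of s.
#     s, count = get_sum_abs_stats(x, dim=1, stats_type="pos-ratio")
#     # s == tensor([0., 1., 1.]): per-column count of positive entries.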
def get_diagnostics_for_dim(
    dim: int,
    tensors: List[Tensor],
    options: TensorDiagnosticOptions,
    sizes_same: bool,
    stats_type: str,
) -> str:
    """This function gets diagnostics for a dimension of a module.

    Args:
      dim:
        The dimension to analyze, with 0 <= dim < tensors[0].ndim.
      tensors:
        List of cached tensors to get the stats from.
      options:
        Options object.
      sizes_same:
        True if all the tensor sizes are the same on this dimension.
      stats_type:
        Either "mean-abs" or "pos-ratio"; dictates the type of stats we
        accumulate. "mean-abs" is the mean absolute value; "pos-ratio" is
        the proportion of positive values.
    Returns:
      Diagnostic as a string, either percentiles or the actual values,
      see the code.
    """
    # stats_and_counts is a list of pairs (Tensor, int).
    stats_and_counts = [get_sum_abs_stats(x, dim, stats_type) for x in tensors]
    stats = [x[0] for x in stats_and_counts]
    counts = [x[1] for x in stats_and_counts]

    if sizes_same:
        stats = torch.stack(stats).sum(dim=0)
        count = sum(counts)
        stats = stats / count
    else:
        stats = [x[0] / x[1] for x in stats_and_counts]
        stats = torch.cat(stats, dim=0)

    # If `summarize` is true, we print percentiles of the stats;
    # otherwise, we print out the individual elements.
    summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
    if summarize:
        # Print out percentiles.
        stats = stats.sort()[0]
        num_percentiles = 10
        size = stats.numel()
        percentiles = []
        for i in range(num_percentiles + 1):
            index = (i * (size - 1)) // num_percentiles
            percentiles.append(stats[index].item())
        percentiles = ["%.2g" % x for x in percentiles]
        percentiles = " ".join(percentiles)
        return f"percentiles: [{percentiles}]"
    else:
        stats = stats.tolist()
        stats = ["%.2g" % x for x in stats]
        stats = "[" + " ".join(stats) + "]"
        return stats
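
# A sketch of the percentile indexing used above: for stats of size 8
# (sorted ascending), the eleven indices (i * (8 - 1)) // 10 for
# i = 0..10 are 0, 0, 1, 2, 2, 3, 4, 4, 5, 6, 7, so the printed list
# always starts at the minimum and ends at the maximum of the stats.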
def print_diagnostics_for_dim(
    name: str, dim: int, tensors: List[Tensor], options: TensorDiagnosticOptions
):
    """This function prints diagnostics for a dimension of a tensor.

    Args:
      name:
        The tensor name.
      dim:
        The dimension to analyze, with 0 <= dim < tensors[0].ndim.
      tensors:
        List of cached tensors to get the stats from.
      options:
        Options object.
    """
    for stats_type in ["mean-abs", "pos-ratio"]:
        sizes = [x.shape[dim] for x in tensors]
        sizes_same = all(x == sizes[0] for x in sizes)
        s = get_diagnostics_for_dim(
            dim, tensors, options, sizes_same, stats_type
        )

        min_size = min(sizes)
        max_size = max(sizes)
        size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}"
        print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}")
class TensorDiagnostic(object):
    """This class is not directly used by the user; it is responsible for
    collecting diagnostics for a single tensor of a torch.nn.Module.

    Args:
      opts:
        Options object.
      name:
        The tensor name.
    """

    def __init__(self, opts: TensorDiagnosticOptions, name: str):
        self.name = name
        self.opts = opts
        # A list to cache the tensors.
        self.saved_tensors = []

    def accumulate(self, x):
        """Accumulate tensors."""
        if isinstance(x, tuple):
            x = x[0]
        if not isinstance(x, Tensor):
            return
        if x.device == torch.device("cpu"):
            x = x.detach().clone()
        else:
            x = x.detach().to("cpu", non_blocking=True)
        self.saved_tensors.append(x)
        num = len(self.saved_tensors)
        if num & (num - 1) == 0:  # Check only when num is a power of 2.
            self._limit_memory()

    def _limit_memory(self):
        """Only keep the most recently cached tensors, to limit memory."""
        if len(self.saved_tensors) > 1024:
            self.saved_tensors = self.saved_tensors[-1024:]
            return

        tot_mem = 0.0
        for i in reversed(range(len(self.saved_tensors))):
            tot_mem += (
                self.saved_tensors[i].numel()
                * self.saved_tensors[i].element_size()
            )
            if tot_mem > self.opts.memory_limit:
                self.saved_tensors = self.saved_tensors[i:]
                return

    def print_diagnostics(self):
        """Print diagnostics for each dimension of the tensor."""
        if len(self.saved_tensors) == 0:
            print("{name}: no stats".format(name=self.name))
            return

        if self.saved_tensors[0].ndim == 0:
            # Ensure there is at least one dim.
            self.saved_tensors = [x.unsqueeze(0) for x in self.saved_tensors]

        ndim = self.saved_tensors[0].ndim
        for dim in range(ndim):
            print_diagnostics_for_dim(
                self.name, dim, self.saved_tensors, self.opts
            )
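
# A sketch of the caching behavior above: accumulate() detaches tensors to
# the CPU and re-checks the memory budget whenever the cache size reaches a
# power of two, so the check runs only O(log n) times over n calls.
#
#     diag = TensorDiagnostic(TensorDiagnosticOptions(2 ** 20), "demo")
#     for _ in range(100):
#         diag.accumulate(torch.randn(256, 256))  # 256 KiB each in float32
#     # Only the newest tensors fitting in the ~1 MiB budget are kept.
#     diag.print_diagnostics()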
class ModelDiagnostic(object):
    """This class stores diagnostics for all tensors in a torch.nn.Module.

    Args:
      opts:
        Options object.
    """

    def __init__(self, opts: TensorDiagnosticOptions):
        # In this dictionary, the keys are tensor names and the values
        # are the corresponding TensorDiagnostic objects.
        self.diagnostics = dict()
        self.opts = opts

    def __getitem__(self, name: str):
        if name not in self.diagnostics:
            self.diagnostics[name] = TensorDiagnostic(self.opts, name)
        return self.diagnostics[name]

    def print_diagnostics(self):
        """Print diagnostics for each tensor."""
        for k in sorted(self.diagnostics.keys()):
            self.diagnostics[k].print_diagnostics()
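
# Note that __getitem__ creates a TensorDiagnostic on first access, so the
# hooks below can write to arbitrary names without pre-registration; e.g.
# (with a hypothetical tensor name):
#
#     md = ModelDiagnostic(TensorDiagnosticOptions(2 ** 22))
#     md["encoder.output"].accumulate(torch.randn(4, 8))  # auto-created
#     md.print_diagnostics()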
def attach_diagnostics(
    model: nn.Module, opts: TensorDiagnosticOptions
) -> ModelDiagnostic:
    """Attach a ModelDiagnostic object to the model by
    1) registering a forward hook and a backward hook on each module,
       to accumulate its output tensors and gradient tensors, respectively;
    2) registering a backward hook on each module parameter, to accumulate
       its values and gradients.

    Args:
      model:
        The model to be analyzed.
      opts:
        Options object.
    Returns:
      The ModelDiagnostic object attached to the model.
    """
    ans = ModelDiagnostic(opts)
    for name, module in model.named_modules():
        if name == "":
            name = "<top-level>"

        # Setting _model_diagnostic=ans and _name=name below, instead of
        # trying to capture the variables directly, ensures that we use the
        # current values (this matters for name, since the loop variable
        # gets overwritten; closures capture only "the final value the
        # variable got in the function", not the value at definition time).
        def forward_hook(
            _module, _input, _output, _model_diagnostic=ans, _name=name
        ):
            if isinstance(_output, Tensor):
                _model_diagnostic[f"{_name}.output"].accumulate(_output)
            elif isinstance(_output, tuple):
                for i, o in enumerate(_output):
                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o)

        def backward_hook(
            _module, _input, _output, _model_diagnostic=ans, _name=name
        ):
            if isinstance(_output, Tensor):
                _model_diagnostic[f"{_name}.grad"].accumulate(_output)
            elif isinstance(_output, tuple):
                for i, o in enumerate(_output):
                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o)

        module.register_forward_hook(forward_hook)
        module.register_backward_hook(backward_hook)

    for name, parameter in model.named_parameters():

        def param_backward_hook(
            grad, _parameter=parameter, _model_diagnostic=ans, _name=name
        ):
            _model_diagnostic[f"{_name}.param_value"].accumulate(_parameter)
            _model_diagnostic[f"{_name}.param_grad"].accumulate(grad)

        parameter.register_hook(param_backward_hook)

    return ans
def _test_tensor_diagnostic():
    opts = TensorDiagnosticOptions(2 ** 20)

    diagnostic = TensorDiagnostic(opts, "foo")
    for _ in range(10):
        diagnostic.accumulate(torch.randn(50, 100) * 10.0)
    diagnostic.print_diagnostics()

    model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80))
    diagnostic = attach_diagnostics(model, opts)
    for _ in range(10):
        T = random.randint(200, 300)
        x = torch.randn(T, 100)
        y = model(x)
        y.sum().backward()
    diagnostic.print_diagnostics()


if __name__ == "__main__":
    _test_tensor_diagnostic()