Adding diagnostics code...

Daniel Povey 2022-02-27 13:44:43 +08:00
parent 2af1b3af98
commit 581786a6d3
2 changed files with 313 additions and 11 deletions

diagnostics.py (new file)

@@ -0,0 +1,284 @@
import torch
from torch import Tensor
from torch import nn
import math
import random
from typing import Tuple, List
class TensorDiagnosticOptions(object):
"""
Options object for tensor diagnostics:
Args:
memory_limit: the maximum number of bytes per tensor (limits how many copies
of the tensor we cache).
"""
def __init__(self, memory_limit: int,
print_pos_ratio: bool = True):
self.memory_limit = memory_limit
self.print_pos_ratio = print_pos_ratio
    def dim_is_summarized(self, size: int):
        # Dims longer than 10 are summarized as percentiles rather than
        # printed in full; size 31 is special-cased to be printed in full.
        return size > 10 and size != 31
def stats_types(self):
if self.print_pos_ratio:
return ["mean-abs", "pos-ratio"]
else:
return ["mean-abs"]
def get_sum_abs_stats(x: Tensor, dim: int,
stats_type: str) -> Tuple[Tensor, int]:
"""
Returns the sum-of-absolute-value of this Tensor, for each
index into the specified axis/dim of the tensor.
Args:
x: Tensor, tensor to be analyzed
dim: dimension with 0 <= dim < x.ndim
stats_type: either "mean-abs" in which case the stats represent the
mean absolute value, or "pos-ratio" in which case the
stats represent the proportion of positive values (actually:
the tensor is count of positive values, count is the count of
all values).
Returns (sum_abs, count)
where sum_abs is a Tensor of shape (x.shape[dim],), and the count
is an integer saying how many items were counted in each element
of sum_abs.
"""
if stats_type == "mean-abs":
x = x.abs()
else:
assert stats_type == "pos-ratio"
x = (x > 0).to(dtype=torch.float)
orig_numel = x.numel()
    sum_dims = [ d for d in range(x.ndim) if d != dim ]
    if len(sum_dims) > 0:
        # Guard against an empty dim-list: torch.sum(x, dim=[]) may reduce
        # over *all* dims in some torch versions, which is not what we want
        # when x is 1-dimensional.
        x = torch.sum(x, dim=sum_dims)
count = orig_numel // x.numel()
x = x.flatten()
return x, count
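# Example of the expected behavior (a sketch for illustration):
#   x = torch.tensor([[1., -2.], [3., -4.]])
#   get_sum_abs_stats(x, dim=1, stats_type="mean-abs") -> (tensor([4., 6.]), 2)
# i.e. |x| summed over all dims other than dim 1, with 2 items per element.
# With stats_type="pos-ratio" it would return (tensor([2., 0.]), 2), since
# both entries of column 0 are positive and neither entry of column 1 is.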
def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
options: TensorDiagnosticOptions,
sizes_same: bool,
stats_type: str):
"""
This function gets diagnostics for a dimension of a module.
Args:
dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
options: options object
sizes_same: true if all the tensor sizes are the same on this dimension
stats_type: either "mean-abs" or "pos-ratio", dictates the type of stats
we accumulate, mean-abs is mean absolute value, "pos-ratio"
is proportion of positive to nonnegative values.
Returns:
Diagnostic as a string, either percentiles or the actual values,
see the code.
"""
    # stats_and_counts is a list of pairs (Tensor, int).
stats_and_counts = [ get_sum_abs_stats(x, dim, stats_type) for x in tensors ]
stats = [ x[0] for x in stats_and_counts ]
counts = [ x[1] for x in stats_and_counts ]
if sizes_same:
stats = torch.stack(stats).sum(dim=0)
count = sum(counts)
stats = stats / count
else:
stats = [ x[0] / x[1] for x in stats_and_counts ]
stats = torch.cat(stats, dim=0)
# if `summarize` we print percentiles of the stats; else,
# we print out individual elements.
summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
if summarize:
# print out percentiles.
stats = stats.sort()[0]
num_percentiles = 10
size = stats.numel()
percentiles = []
for i in range(num_percentiles + 1):
index = (i * (size - 1)) // num_percentiles
percentiles.append(stats[index].item())
percentiles = [ '%.2g' % x for x in percentiles ]
percentiles = ' '.join(percentiles)
return f'percentiles: [{percentiles}]'
else:
stats = stats.tolist()
stats = [ '%.2g' % x for x in stats ]
stats = '[' + ' '.join(stats) + ']'
return stats
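# Worked example of the percentile indexing above (illustrative): with
# stats.numel() == 101 and num_percentiles == 10, the indices visited are
# (i * 100) // 10 for i in 0..10, i.e. 0, 10, 20, ..., 100; so we print the
# min, the 10th, 20th, ..., 90th percentiles, and the max of the sorted stats.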
def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor],
                              options: TensorDiagnosticOptions):
    sizes = [ x.shape[dim] for x in tensors ]
    sizes_same = all([ x == sizes[0] for x in sizes ])
    for stats_type in options.stats_types():
        # stats_type will be "mean-abs" or "pos-ratio".
        s = get_diagnostics_for_dim(dim, tensors,
                                    options, sizes_same,
                                    stats_type)
        min_size = min(sizes)
        max_size = max(sizes)
        size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}"
        print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}")
class TensorDiagnostic(object):
"""
This class is not directly used by the user, it is responsible for collecting
diagnostics for a single parameter tensor of a torch.Module.
"""
def __init__(self,
opts: TensorDiagnosticOptions,
name: str):
self.name = name
self.opts = opts
self.saved_tensors = []
    def accumulate(self, x):
        if isinstance(x, tuple):
            x = x[0]
        if not isinstance(x, Tensor):
            return
        if x.device == torch.device('cpu'):
            x = x.detach().clone()
        else:
            x = x.detach().to('cpu', non_blocking=True)
        self.saved_tensors.append(x)
        num_saved = len(self.saved_tensors)
        if (num_saved & (num_saved - 1)) == 0:
            # num_saved is a power of 2; check the memory limit only at these
            # points, to amortize the cost of _limit_memory().  (Note: the
            # parentheses matter, since == binds tighter than & in Python.)
            self._limit_memory()
    def _limit_memory(self):
        if len(self.saved_tensors) > 1024:
            self.saved_tensors = self.saved_tensors[-1024:]
            return
        tot_mem = 0.0
        # Walk backward from the most recently saved tensor, accumulating
        # memory usage; once the limit is exceeded, keep only the tensors
        # from that point on (i.e. drop the older ones).
        for i in reversed(range(len(self.saved_tensors))):
            tot_mem += self.saved_tensors[i].numel() * self.saved_tensors[i].element_size()
            if tot_mem > self.opts.memory_limit:
                self.saved_tensors = self.saved_tensors[i:]
                return
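    # Worked example of the accounting above (illustrative): a float32 tensor
    # of shape (50, 100) takes 50 * 100 * 4 = 20000 bytes, so with
    # memory_limit = 2**20 (1 MiB) roughly the 52 most recent such tensors
    # would be kept.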
def print_diagnostics(self):
if len(self.saved_tensors) == 0:
print("{name}: no stats".format(name=self.name))
return
if self.saved_tensors[0].ndim == 0:
# ensure there is at least one dim.
self.saved_tensors = [ x.unsqueeze(0) for x in self.saved_tensors ]
ndim = self.saved_tensors[0].ndim
for dim in range(ndim):
print_diagnostics_for_dim(self.name, dim,
self.saved_tensors,
self.opts)
class ModelDiagnostic(object):
    """
    This class contains a dict from tensor name to TensorDiagnostic object;
    it is what attach_diagnostics() returns, and its print_diagnostics()
    method prints the stats for all tracked tensors.
    """
def __init__(self, opts: TensorDiagnosticOptions):
self.diagnostics = dict()
self.opts = opts
def __getitem__(self, name: str):
if name not in self.diagnostics:
self.diagnostics[name] = TensorDiagnostic(self.opts, name)
return self.diagnostics[name]
def print_diagnostics(self):
for k in sorted(self.diagnostics.keys()):
self.diagnostics[k].print_diagnostics()
def attach_diagnostics(model: nn.Module,
                       opts: TensorDiagnosticOptions) -> ModelDiagnostic:
    """
    Attaches forward and backward hooks to all modules of `model`, and
    gradient hooks to its parameters, accumulating stats on module outputs,
    output gradients, parameter values and parameter gradients.  Returns a
    ModelDiagnostic object; call its print_diagnostics() method to print
    the accumulated stats.
    """
ans = ModelDiagnostic(opts)
for name, module in model.named_modules():
if name == '':
name = "<top-level>"
forward_diagnostic = TensorDiagnostic(opts, name + ".output")
backward_diagnostic = TensorDiagnostic(opts, name + ".grad")
        # Passing _model_diagnostic=ans and _name=name as default arguments
        # below, instead of letting the closures capture the variables,
        # ensures the hooks see the current values.  This matters for `name`,
        # since the loop variable gets overwritten: Python closures capture
        # variables by reference, not by value, so without this every hook
        # would see the final value that `name` took in this function.
def forward_hook(_module, _input, _output,
_model_diagnostic=ans, _name=name):
if isinstance(_output, Tensor):
_model_diagnostic[f"{_name}.output"].accumulate(_output)
elif isinstance(_output, tuple):
for i, o in enumerate(_output):
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(o)
def backward_hook(_module, _input, _output,
_model_diagnostic=ans, _name=name):
if isinstance(_output, Tensor):
_model_diagnostic[f"{_name}.grad"].accumulate(_output)
elif isinstance(_output, tuple):
for i, o in enumerate(_output):
_model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o)
module.register_forward_hook(forward_hook)
module.register_backward_hook(backward_hook)
for name, parameter in model.named_parameters():
def param_backward_hook(grad,
_parameter=parameter,
_model_diagnostic=ans,
_name=name):
_model_diagnostic[f"{_name}.param_value"].accumulate(_parameter)
_model_diagnostic[f"{_name}.param_grad"].accumulate(grad)
parameter.register_hook(param_backward_hook)
return ans
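# Illustrative: for the model in _test_tensor_diagnostic() below, this would
# register stats under names like "<top-level>.output", "0.output",
# "0.grad[0]" (grad_output arrives as a tuple), "0.weight.param_value" and
# "0.weight.param_grad", where "0" is the name nn.Sequential gives its first
# submodule.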
def _test_tensor_diagnostic():
opts = TensorDiagnosticOptions(2**20, True)
diagnostic = TensorDiagnostic(opts, "foo")
for _ in range(10):
diagnostic.accumulate(torch.randn(50, 100) * 10.0)
diagnostic.print_diagnostics()
model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80))
diagnostic = attach_diagnostics(model, opts)
for _ in range(10):
T = random.randint(200, 300)
x = torch.randn(T, 100)
y = model(x)
y.sum().backward()
diagnostic.print_diagnostics()
if __name__ == '__main__':
_test_tensor_diagnostic()
def _test_func():
    # Scratch code, not called from anywhere: it demonstrates the closure
    # pitfall worked around in attach_diagnostics().  Each `func` below
    # closes over the variable x itself, not its value at definition time,
    # so all ten returned functions return the same final list, [9].
    ans = []
    for i in range(10):
        x = list()
        x.append(i)
        def func():
            return x
        ans.append(func)
    return ans
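# For contrast, a minimal sketch of the fix (the same default-argument trick
# used by the hooks in attach_diagnostics above): writing
#     def func(_x=x):
#         return _x
# inside the loop would bind the value of x at definition time, so each
# returned function would return its own one-element list [i].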

transducer_stateless/train.py

@@ -34,6 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
import argparse
import logging
import diagnostics # ./diagnostics.py
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
@@ -109,7 +110,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="transducer_stateless/exp-100h-specaugmod_p0.9_0.15_fix",
default="transducer_stateless/specaugmod_baseline",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
@@ -138,6 +139,13 @@ def get_parser():
"2 means tri-gram",
)
parser.add_argument(
"--print-diagnostics",
type=str2bool,
default=False,
help="Accumulate stats on activations, print them and exit.",
)
return parser
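# Example invocation (illustrative; the script path is an assumption):
#   ./transducer_stateless/train.py --print-diagnostics=True
# With this flag set, training stops after a few batches, prints the
# accumulated activation/gradient stats, and exits without saving checkpoints.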
@@ -487,6 +495,9 @@ def train_one_epoch(
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if params.print_diagnostics and batch_idx == 5:
return
if batch_idx % params.log_interval == 0:
logging.info(
@@ -494,9 +505,6 @@
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if batch_idx % params.log_interval == 0:
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
@@ -599,6 +607,11 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args)
if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(2**22)  # allow 4 megabytes per tracked tensor
diagnostic = diagnostics.attach_diagnostics(model, opts)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
@@ -626,6 +639,7 @@ def run(rank, world_size, args):
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
@@ -660,6 +674,10 @@
world_size=world_size,
)
if params.print_diagnostics:
diagnostic.print_diagnostics()
break
save_checkpoint(
params=params,
model=model,