Adding diagnostics code...

This commit is contained in:
Daniel Povey 2022-02-27 13:44:43 +08:00
parent 2af1b3af98
commit 581786a6d3
2 changed files with 313 additions and 11 deletions

View File

@ -0,0 +1,284 @@
import torch
from torch import Tensor
from torch import nn
import math
import random
from typing import Tuple, List
class TensorDiagnosticOptions(object):
    """
    Options object for tensor diagnostics.

    Args:
      memory_limit: the maximum number of bytes per tensor (limits how many
                    copies of the tensor we cache).
      print_pos_ratio: if true, the "pos-ratio" statistic is reported in
                    addition to "mean-abs"; otherwise only "mean-abs" is.
    """

    def __init__(self, memory_limit: int,
                 print_pos_ratio: bool = True):
        self.print_pos_ratio = print_pos_ratio
        self.memory_limit = memory_limit

    def dim_is_summarized(self, size: int):
        """Return True if a dimension of this size is printed as percentiles
        rather than element by element."""
        # Sizes up to 10, and the size 31 specifically, are never summarized.
        if size == 31:
            return False
        return size > 10

    def stats_types(self):
        """Return the names of the statistics to report, in print order."""
        names = ["mean-abs"]
        if self.print_pos_ratio:
            names.append("pos-ratio")
        return names
def get_sum_abs_stats(x: Tensor, dim: int,
                      stats_type: str) -> Tuple[Tensor, int]:
    """
    Returns summed statistics of this Tensor, for each index into the
    specified axis/dim of the tensor.

    Args:
      x: Tensor, tensor to be analyzed
      dim: dimension with 0 <= dim < x.ndim
      stats_type: either "mean-abs", in which case the returned tensor holds
           the sum of absolute values, or "pos-ratio", in which case it holds
           the count of strictly positive values.  Dividing by the returned
           count gives the mean absolute value / the proportion of positive
           values respectively.
    Returns (stats, count)
      where stats is a Tensor of shape (x.shape[dim],), and the count
      is an integer saying how many items were summed into each element
      of stats.
    """
    if stats_type == "mean-abs":
        x = x.abs()
    else:
        assert stats_type == "pos-ratio"
        x = (x > 0).to(dtype=torch.float)
    orig_numel = x.numel()
    sum_dims = [d for d in range(x.ndim) if d != dim]
    # Only reduce when there is something to reduce over.  For a 1-D (or 0-D)
    # tensor `sum_dims` is empty, and passing an empty dim list to torch.sum()
    # can reduce over *every* dimension (depending on the PyTorch version),
    # which would collapse the per-index stats into a single scalar.
    if sum_dims:
        x = torch.sum(x, dim=sum_dims)
    count = orig_numel // x.numel()
    x = x.flatten()
    return x, count
def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
                            options: TensorDiagnosticOptions,
                            sizes_same: bool,
                            stats_type: str):
    """
    Computes the diagnostic string for one dimension of a list of tensors.

    Args:
      dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
      tensors: the tensors to gather statistics from
      options: options object
      sizes_same: true if all the tensor sizes are the same on this dimension
      stats_type: either "mean-abs" or "pos-ratio"; selects which statistic
          is accumulated by get_sum_abs_stats(): mean absolute value, or
          the proportion of values that are strictly positive.
    Returns:
      Diagnostic as a string: either 11 percentiles (0%, 10%, ... 100%) of
      the per-index stats, or the per-index values themselves.
    """
    # One (summed_stats, count) pair per tensor.
    per_tensor = [get_sum_abs_stats(t, dim, stats_type) for t in tensors]

    if sizes_same:
        # Same size everywhere: pool the sums over all tensors, then
        # normalize once by the total count.
        pooled = torch.stack([s for s, _ in per_tensor]).sum(dim=0)
        averaged = pooled / sum(n for _, n in per_tensor)
    else:
        # Sizes differ: normalize each tensor separately and concatenate.
        averaged = torch.cat([s / n for s, n in per_tensor], dim=0)

    # Individual elements are printed only when the sizes agree and the
    # options say this size is small enough; otherwise print percentiles.
    if sizes_same and not options.dim_is_summarized(averaged.numel()):
        return '[' + ' '.join('%.2g' % v for v in averaged.tolist()) + ']'

    ordered, _ = averaged.sort()
    num_percentiles = 10
    last = ordered.numel() - 1
    picks = [ordered[(i * last) // num_percentiles].item()
             for i in range(num_percentiles + 1)]
    joined = ' '.join('%.2g' % v for v in picks)
    return f'percentiles: [{joined}]'
def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor],
                              options: TensorDiagnosticOptions):
    """
    Prints one diagnostics line per statistic type for one dimension.

    Args:
      name: label printed as `module=...` on each line
      dim: the dimension to analyze
      tensors: the tensors to gather statistics from
      options: options object; its stats_types() decides which statistics
          ("mean-abs" and possibly "pos-ratio") are printed
    """
    # The sizes along `dim` do not depend on the statistic being printed,
    # so compute them once rather than once per stats type.
    sizes = [t.shape[dim] for t in tensors]
    sizes_same = all(sz == sizes[0] for sz in sizes)
    for stats_type in options.stats_types():
        # stats_type will be "mean-abs" or "pos-ratio".
        s = get_diagnostics_for_dim(dim, tensors,
                                    options, sizes_same,
                                    stats_type)
        min_size = min(sizes)
        max_size = max(sizes)
        if sizes_same:
            size_str = f"{min_size}"
        else:
            size_str = f"{min_size}..{max_size}"
        print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}")
class TensorDiagnostic(object):
    """
    This class is not directly used by the user; it is responsible for
    collecting diagnostics for a single named tensor (copies of the tensor
    are cached on the CPU each time accumulate() is called).

    Args:
      opts: options object; `opts.memory_limit` bounds the cached bytes
      name: label used when printing the diagnostics
    """

    def __init__(self,
                 opts: "TensorDiagnosticOptions",
                 name: str):
        self.name = name
        self.opts = opts
        self.saved_tensors = []

    def accumulate(self, x):
        """Cache a detached CPU copy of `x`.

        If `x` is a tuple, only its first element is considered; anything
        that is not a Tensor (including an empty tuple) is silently ignored.
        """
        if isinstance(x, tuple):
            x = x[0] if x else None
        if not isinstance(x, Tensor):
            return
        if x.device == torch.device('cpu'):
            # clone so that later in-place changes don't affect the stats.
            x = x.detach().clone()
        else:
            # NOTE(review): a non_blocking device-to-host copy is assumed to
            # have completed by the time the stats are printed -- confirm.
            x = x.detach().to('cpu', non_blocking=True)
        self.saved_tensors.append(x)
        num_saved = len(self.saved_tensors)
        # Only check the limits when the count is a power of 2, so the cost
        # of the check is amortized over many calls.
        if num_saved & (num_saved - 1) == 0:
            self._limit_memory()

    def _limit_memory(self):
        """Drop the oldest cached tensors so that at most 1024 remain and
        the byte limit is respected."""
        if len(self.saved_tensors) > 1024:
            self.saved_tensors = self.saved_tensors[-1024:]
            # Fall through: the count cap alone does not guarantee that the
            # byte limit holds, so it is still enforced below.
        tot_mem = 0
        # Walk from the newest tensor backwards, keeping the most recent ones.
        for i in reversed(range(len(self.saved_tensors))):
            t = self.saved_tensors[i]
            tot_mem += t.numel() * t.element_size()
            if tot_mem > self.opts.memory_limit:
                # The tensor that crosses the limit is itself retained, so
                # at least one tensor always survives.
                self.saved_tensors = self.saved_tensors[i:]
                return

    def print_diagnostics(self):
        """Print the diagnostics for every dimension of the cached tensors,
        or a "no stats" line if nothing was accumulated."""
        if len(self.saved_tensors) == 0:
            print("{name}: no stats".format(name=self.name))
            return
        tensors = self.saved_tensors
        if tensors[0].ndim == 0:
            # ensure there is at least one dim.  Work on a local list so that
            # printing does not alter the cached tensors.
            tensors = [x.unsqueeze(0) for x in tensors]
        ndim = tensors[0].ndim
        for dim in range(ndim):
            print_diagnostics_for_dim(self.name, dim,
                                      tensors,
                                      self.opts)
class ModelDiagnostic(object):
    """
    Holds one TensorDiagnostic per name, creating them on first access.

    Args:
      opts: options object passed on to every TensorDiagnostic created here
    """

    def __init__(self, opts: TensorDiagnosticOptions):
        self.opts = opts
        self.diagnostics = {}

    def __getitem__(self, name: str):
        """Return the TensorDiagnostic for `name`, creating it if needed."""
        try:
            return self.diagnostics[name]
        except KeyError:
            entry = TensorDiagnostic(self.opts, name)
            self.diagnostics[name] = entry
            return entry

    def print_diagnostics(self):
        """Print the diagnostics of every tensor, in sorted name order."""
        for key in sorted(self.diagnostics):
            self.diagnostics[key].print_diagnostics()
def attach_diagnostics(model: nn.Module,
                       opts: TensorDiagnosticOptions) -> ModelDiagnostic:
    """
    Registers hooks on `model` that accumulate diagnostics as it is run.

    For every sub-module (the model itself is named "<top-level>") the
    forward outputs are recorded under "<name>.output" and the third argument
    of the backward hook under "<name>.grad"; tuple values are recorded per
    element as "<name>.output[i]" / "<name>.grad[i]".  For every parameter
    that requires grad, its value and its gradient are recorded under
    "<name>.param_value" and "<name>.param_grad" each time its gradient is
    computed.

    Args:
      model: the model to attach hooks to
      opts: options object used for all the collected diagnostics
    Returns:
      The ModelDiagnostic object that the hooks write to; call its
      print_diagnostics() after running the model.
    """
    ans = ModelDiagnostic(opts)
    for name, module in model.named_modules():
        if name == '':
            name = "<top-level>"

        # setting _model_diagnostic=ans and _name=name below, instead of
        # capturing the variables, ensures that we use the current values
        # (matters for name, since the variable gets overwritten): closures
        # don't capture by value, they see the final value the variable got.
        def forward_hook(_module, _input, _output,
                         _model_diagnostic=ans, _name=name):
            if isinstance(_output, Tensor):
                _model_diagnostic[f"{_name}.output"].accumulate(_output)
            elif isinstance(_output, tuple):
                for i, o in enumerate(_output):
                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o)

        # For a module backward hook the arguments are
        # (module, grad_input, grad_output), so `_output` here holds the
        # gradient(s) w.r.t. the module's output.
        def backward_hook(_module, _input, _output,
                          _model_diagnostic=ans, _name=name):
            if isinstance(_output, Tensor):
                _model_diagnostic[f"{_name}.grad"].accumulate(_output)
            elif isinstance(_output, tuple):
                for i, o in enumerate(_output):
                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o)

        module.register_forward_hook(forward_hook)
        module.register_backward_hook(backward_hook)

    for name, parameter in model.named_parameters():
        # Tensor.register_hook() raises for tensors that don't require grad,
        # and such parameters never receive a gradient anyway.
        if not parameter.requires_grad:
            continue

        def param_backward_hook(grad,
                                _parameter=parameter,
                                _model_diagnostic=ans,
                                _name=name):
            _model_diagnostic[f"{_name}.param_value"].accumulate(_parameter)
            _model_diagnostic[f"{_name}.param_grad"].accumulate(grad)

        parameter.register_hook(param_backward_hook)
    return ans
def _test_tensor_diagnostic():
    """Smoke test: print diagnostics for a standalone tensor stream, then for
    a small two-layer model run on batches of random length."""
    opts = TensorDiagnosticOptions(2**20, True)

    # Part 1: a single TensorDiagnostic fed directly.
    tensor_diag = TensorDiagnostic(opts, "foo")
    for _ in range(10):
        tensor_diag.accumulate(torch.randn(50, 100) * 10.0)
    tensor_diag.print_diagnostics()

    # Part 2: hooks attached to a model; the first dim varies per batch.
    model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80))
    model_diag = attach_diagnostics(model, opts)
    for _ in range(10):
        num_frames = random.randint(200, 300)
        feats = torch.randn(num_frames, 100)
        out = model(feats)
        out.sum().backward()
    model_diag.print_diagnostics()
# Run the smoke test when this file is executed as a script.
if __name__ == '__main__':
    _test_tensor_diagnostic()
def _test_func():
    """Demonstrates late binding of closures.

    Returns a list of 10 functions.  Each one returns whatever `x` refers to
    at call time, so once the loop has finished they all return the list
    created in the final iteration.
    """
    closures = []
    for i in range(10):
        x = [i]

        def func():
            # `x` is looked up when func() is called, not when it is defined.
            return x

        closures.append(func)
    return closures

View File

@ -34,6 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
import argparse import argparse
import logging import logging
import diagnostics # ./diagnostics.py
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from typing import Optional, Tuple from typing import Optional, Tuple
@ -109,7 +110,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="transducer_stateless/exp-100h-specaugmod_p0.9_0.15_fix", default="transducer_stateless/specaugmod_baseline",
help="""The experiment dir. help="""The experiment dir.
It specifies the directory where all training related It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved files, e.g., checkpoints, log, etc, are saved
@ -138,6 +139,13 @@ def get_parser():
"2 means tri-gram", "2 means tri-gram",
) )
parser.add_argument(
"--print-diagnostics",
type=str2bool,
default=False,
help="Accumulate stats on activations, print them and exit.",
)
return parser return parser
@ -487,6 +495,9 @@ def train_one_epoch(
loss.backward() loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0) clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step() optimizer.step()
if params.print_diagnostics and batch_idx == 5:
return
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
logging.info( logging.info(
@ -494,9 +505,6 @@ def train_one_epoch(
f"batch {batch_idx}, loss[{loss_info}], " f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}" f"tot_loss[{tot_loss}], batch size: {batch_size}"
) )
if batch_idx % params.log_interval == 0:
if tb_writer is not None: if tb_writer is not None:
loss_info.write_summary( loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train tb_writer, "train/current_", params.batch_idx_train
@ -599,6 +607,11 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(2**22) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
train_cuts = librispeech.train_clean_100_cuts() train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri: if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts() train_cuts += librispeech.train_clean_360_cuts()
@ -626,6 +639,7 @@ def run(rank, world_size, args):
valid_cuts += librispeech.dev_other_cuts() valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts) valid_dl = librispeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom( scan_pessimistic_batches_for_oom(
model=model, model=model,
train_dl=train_dl, train_dl=train_dl,
@ -660,6 +674,10 @@ def run(rank, world_size, args):
world_size=world_size, world_size=world_size,
) )
if params.print_diagnostics:
diagnostic.print_diagnostics()
break
save_checkpoint( save_checkpoint(
params=params, params=params,
model=model, model=model,