From 581786a6d367e7d9313c43ae12030bc6044c9d0c Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sun, 27 Feb 2022 13:44:43 +0800
Subject: [PATCH] Adding diagnostics code...

---
 .../ASR/transducer_stateless/diagnostics.py  | 284 ++++++++++++++++++
 .../ASR/transducer_stateless/train.py        |  40 ++-
 2 files changed, 313 insertions(+), 11 deletions(-)
 create mode 100644 egs/librispeech/ASR/transducer_stateless/diagnostics.py

diff --git a/egs/librispeech/ASR/transducer_stateless/diagnostics.py b/egs/librispeech/ASR/transducer_stateless/diagnostics.py
new file mode 100644
index 000000000..2dff91805
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_stateless/diagnostics.py
@@ -0,0 +1,284 @@
+import random
+from typing import List, Tuple
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
+class TensorDiagnosticOptions(object):
+    """
+    Options object for tensor diagnostics.
+
+    Args:
+      memory_limit: the maximum number of bytes we cache per tensor name
+        (this limits how many copies of a given tensor we keep around for
+        computing stats).
+      print_pos_ratio: if true, print the proportion of positive values in
+        addition to the mean absolute value.
+    """
+    def __init__(self, memory_limit: int,
+                 print_pos_ratio: bool = True):
+        self.memory_limit = memory_limit
+        self.print_pos_ratio = print_pos_ratio
+
+    def dim_is_summarized(self, size: int) -> bool:
+        # Dimensions this large are printed as percentiles rather than
+        # element by element (size 31 is special-cased and always printed
+        # in full).
+        return size > 10 and size != 31
+
+    def stats_types(self) -> List[str]:
+        if self.print_pos_ratio:
+            return ["mean-abs", "pos-ratio"]
+        else:
+            return ["mean-abs"]
+
+
+def get_sum_abs_stats(x: Tensor, dim: int,
+                      stats_type: str) -> Tuple[Tensor, int]:
+    """
+    Returns summed stats over all axes of the tensor except `dim`, i.e.
+    one accumulated value per index into the specified axis/dim.
+
+    Args:
+      x: Tensor, the tensor to be analyzed
+      dim: dimension with 0 <= dim < x.ndim
+      stats_type: either "mean-abs", in which case the stats are sums of
+        absolute values (to be normalized into mean absolute values by the
+        caller), or "pos-ratio", in which case the stats are counts of
+        positive values (to be normalized into proportions of positive
+        values by the caller).
+    Returns a pair (stats, count), where stats is a Tensor of shape
+    (x.shape[dim],) and count is an integer saying how many items were
+    accumulated into each element of stats.
+    """
+    if stats_type == "mean-abs":
+        x = x.abs()
+    else:
+        assert stats_type == "pos-ratio"
+        x = (x > 0).to(dtype=torch.float)
+    orig_numel = x.numel()
+    sum_dims = [d for d in range(x.ndim) if d != dim]
+    x = torch.sum(x, dim=sum_dims)
+    count = orig_numel // x.numel()
+    x = x.flatten()
+    return x, count
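+
+
+# Worked example for get_sum_abs_stats (illustrative, not executed): for x
+# of shape (2, 3) and dim=1, the "mean-abs" stats are a tensor of shape
+# (3,) whose j'th element is |x[0,j]| + |x[1,j]|, with count == 2; for
+# "pos-ratio" the j'th element is instead the number of positive values in
+# column j.  Dividing stats by count gives the per-index mean.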
+
+
+def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
+                            options: TensorDiagnosticOptions,
+                            sizes_same: bool,
+                            stats_type: str) -> str:
+    """
+    Gets diagnostics for one dimension of a group of cached tensors.
+
+    Args:
+      dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
+      tensors: the list of cached tensors to compute the stats from
+      options: options object
+      sizes_same: true if all the tensors have the same size on dimension `dim`
+      stats_type: either "mean-abs" or "pos-ratio"; dictates the type of stats
+        we accumulate: "mean-abs" is the mean absolute value, "pos-ratio" is
+        the proportion of values that are positive.
+    Returns:
+      The diagnostic as a string: either percentiles or the individual
+      values, depending on whether the dimension is summarized.
+    """
+    # stats_and_counts is a list of pairs (Tensor, int).
+    stats_and_counts = [get_sum_abs_stats(x, dim, stats_type) for x in tensors]
+    stats = [x[0] for x in stats_and_counts]
+    counts = [x[1] for x in stats_and_counts]
+    if sizes_same:
+        stats = torch.stack(stats).sum(dim=0)
+        count = sum(counts)
+        stats = stats / count
+    else:
+        stats = [x[0] / x[1] for x in stats_and_counts]
+        stats = torch.cat(stats, dim=0)
+    # If `summarize` is true we print percentiles of the stats; otherwise
+    # we print the individual elements.
+    summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
+    if summarize:
+        # Print out percentiles.
+        stats = stats.sort()[0]
+        num_percentiles = 10
+        size = stats.numel()
+        percentiles = []
+        for i in range(num_percentiles + 1):
+            index = (i * (size - 1)) // num_percentiles
+            percentiles.append(stats[index].item())
+        percentiles = ['%.2g' % x for x in percentiles]
+        percentiles = ' '.join(percentiles)
+        return f'percentiles: [{percentiles}]'
+    else:
+        stats = stats.tolist()
+        stats = ['%.2g' % x for x in stats]
+        return '[' + ' '.join(stats) + ']'
+
+
+def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor],
+                              options: TensorDiagnosticOptions):
+    for stats_type in options.stats_types():
+        # stats_type will be "mean-abs" or "pos-ratio".
+        sizes = [x.shape[dim] for x in tensors]
+        sizes_same = all(x == sizes[0] for x in sizes)
+        s = get_diagnostics_for_dim(dim, tensors,
+                                    options, sizes_same,
+                                    stats_type)
+        min_size = min(sizes)
+        max_size = max(sizes)
+        size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}"
+        print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}")
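+
+
+# Example of a line printed by print_diagnostics_for_dim (the numbers are
+# made up for illustration; the format is the one produced just above):
+#   module=foo.output, dim=1, size=100, mean-abs percentiles: [7.3 7.6 7.7 7.8 7.9 8 8 8.1 8.2 8.3 8.6]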
+
+
+class TensorDiagnostic(object):
+    """
+    This class is not used directly by the user; it is responsible for
+    collecting diagnostics for a single tensor of a torch.nn.Module, e.g.
+    an output, a gradient or a parameter.
+    """
+    def __init__(self,
+                 opts: TensorDiagnosticOptions,
+                 name: str):
+        self.name = name
+        self.opts = opts
+        self.saved_tensors = []
+
+    def accumulate(self, x):
+        if isinstance(x, tuple):
+            x = x[0]
+        if not isinstance(x, Tensor):
+            return
+        if x.device == torch.device('cpu'):
+            x = x.detach().clone()
+        else:
+            x = x.detach().to('cpu', non_blocking=True)
+        self.saved_tensors.append(x)
+        num_saved = len(self.saved_tensors)
+        if (num_saved & (num_saved - 1)) == 0:
+            # Check the memory use whenever the length reaches a power of 2,
+            # to keep the amortized cost low.
+            self._limit_memory()
+
+    def _limit_memory(self):
+        if len(self.saved_tensors) > 1024:
+            self.saved_tensors = self.saved_tensors[-1024:]
+            return
+        tot_mem = 0.0
+        for i in reversed(range(len(self.saved_tensors))):
+            tot_mem += (self.saved_tensors[i].numel() *
+                        self.saved_tensors[i].element_size())
+            if tot_mem > self.opts.memory_limit:
+                # Keep only the most recently saved tensors.
+                self.saved_tensors = self.saved_tensors[i:]
+                return
+
+    def print_diagnostics(self):
+        if len(self.saved_tensors) == 0:
+            print(f"{self.name}: no stats")
+            return
+        if self.saved_tensors[0].ndim == 0:
+            # Ensure there is at least one dim.
+            self.saved_tensors = [x.unsqueeze(0) for x in self.saved_tensors]
+        ndim = self.saved_tensors[0].ndim
+        for dim in range(ndim):
+            print_diagnostics_for_dim(self.name, dim,
+                                      self.saved_tensors,
+                                      self.opts)
+
+
+class ModelDiagnostic(object):
+    """Maps tensor names to TensorDiagnostic objects, creating them on demand."""
+    def __init__(self, opts: TensorDiagnosticOptions):
+        self.diagnostics = dict()
+        self.opts = opts
+
+    def __getitem__(self, name: str):
+        if name not in self.diagnostics:
+            self.diagnostics[name] = TensorDiagnostic(self.opts, name)
+        return self.diagnostics[name]
+
+    def print_diagnostics(self):
+        for k in sorted(self.diagnostics.keys()):
+            self.diagnostics[k].print_diagnostics()
+
+
+def attach_diagnostics(model: nn.Module,
+                       opts: TensorDiagnosticOptions) -> ModelDiagnostic:
+    ans = ModelDiagnostic(opts)
+    for name, module in model.named_modules():
+        if name == '':
+            # named_modules() yields '' for the root module; give it a
+            # printable name.
+            name = "<top-level>"
+
+        # Passing _model_diagnostic=ans and _name=name below as default
+        # argument values, instead of relying on closure capture, ensures
+        # that we use the current values.  This matters for `name`: Python
+        # closures capture variables by reference, not by value, so the
+        # hooks would otherwise all see whatever value the loop variable
+        # was left with.
+        def forward_hook(_module, _input, _output,
+                         _model_diagnostic=ans, _name=name):
+            if isinstance(_output, Tensor):
+                _model_diagnostic[f"{_name}.output"].accumulate(_output)
+            elif isinstance(_output, tuple):
+                for i, o in enumerate(_output):
+                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o)
+
+        def backward_hook(_module, _input, _output,
+                          _model_diagnostic=ans, _name=name):
+            if isinstance(_output, Tensor):
+                _model_diagnostic[f"{_name}.grad"].accumulate(_output)
+            elif isinstance(_output, tuple):
+                for i, o in enumerate(_output):
+                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o)
+
+        module.register_forward_hook(forward_hook)
+        module.register_backward_hook(backward_hook)
+
+    for name, parameter in model.named_parameters():
+        def param_backward_hook(grad,
+                                _parameter=parameter,
+                                _model_diagnostic=ans,
+                                _name=name):
+            _model_diagnostic[f"{_name}.param_value"].accumulate(_parameter)
+            _model_diagnostic[f"{_name}.param_grad"].accumulate(grad)
+
+        parameter.register_hook(param_backward_hook)
+    return ans
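+
+
+# Note: register_backward_hook is deprecated in recent PyTorch versions in
+# favour of register_full_backward_hook; it is used above so that the code
+# also runs on older versions, but for modules with multiple inputs the
+# gradient stats it reports may be incomplete.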
+
+
+def _test_tensor_diagnostic():
+    opts = TensorDiagnosticOptions(2**20, True)
+
+    diagnostic = TensorDiagnostic(opts, "foo")
+
+    for _ in range(10):
+        diagnostic.accumulate(torch.randn(50, 100) * 10.0)
+
+    diagnostic.print_diagnostics()
+
+    model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80))
+
+    diagnostic = attach_diagnostics(model, opts)
+    for _ in range(10):
+        T = random.randint(200, 300)
+        x = torch.randn(T, 100)
+        y = model(x)
+        y.sum().backward()
+
+    diagnostic.print_diagnostics()
+
+
+def _test_func():
+    # Demonstrates the closure-capture behavior worked around in
+    # attach_diagnostics(): every closure in `ans` returns the final value
+    # of x, i.e. [9], not the value x had when the closure was created.
+    ans = []
+    for i in range(10):
+        x = list()
+        x.append(i)
+
+        def func():
+            return x
+        ans.append(func)
+    return ans
+
+
+if __name__ == '__main__':
+    _test_tensor_diagnostic()
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index 7d1d7ff08..0e1bbeaff 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -34,6 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
 import argparse
 import logging
+import diagnostics  # this is ./diagnostics.py
 from pathlib import Path
 from shutil import copyfile
 from typing import Optional, Tuple
@@ -109,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/exp-100h-specaugmod_p0.9_0.15_fix",
+        default="transducer_stateless/specaugmod_baseline",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -138,6 +139,13 @@ def get_parser():
         "2 means tri-gram",
     )
 
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
     return parser
 
 
@@ -487,6 +495,9 @@ def train_one_epoch(
         loss.backward()
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()
+        if params.print_diagnostics and batch_idx == 5:
+            return
+
 
         if batch_idx % params.log_interval == 0:
             logging.info(
@@ -494,9 +505,7 @@ def train_one_epoch(
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], batch size: {batch_size}"
             )
-
-        if batch_idx % params.log_interval == 0:
-            if tb_writer is not None:
+            if tb_writer is not None:
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
@@ -599,6 +607,11 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
+    if params.print_diagnostics:
+        # Allow 4 megabytes of cached tensors per diagnostic.
+        opts = diagnostics.TensorDiagnosticOptions(2**22)
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
     train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
         train_cuts += librispeech.train_clean_360_cuts()
@@ -626,13 +639,14 @@ def run(rank, world_size, args):
         valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
-    scan_pessimistic_batches_for_oom(
-        model=model,
-        train_dl=train_dl,
-        optimizer=optimizer,
-        sp=sp,
-        params=params,
-    )
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )
 
     for epoch in range(params.start_epoch, params.num_epochs):
         train_dl.sampler.set_epoch(epoch)
@@ -660,6 +674,10 @@ def run(rank, world_size, args):
             world_size=world_size,
         )
 
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
         save_checkpoint(
             params=params,
             model=model,
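
A minimal usage sketch (not part of the patch), following _test_tensor_diagnostic()
above and the train.py changes; the model, shapes and number of batches here are
placeholder choices:

    import torch
    from torch import nn

    import diagnostics

    model = nn.Sequential(nn.Linear(80, 40), nn.ReLU(), nn.Linear(40, 10))
    # Allow 4 MB of cached tensors per diagnostic.
    opts = diagnostics.TensorDiagnosticOptions(2**22)
    diagnostic = diagnostics.attach_diagnostics(model, opts)

    # A few forward/backward passes so the hooks can accumulate stats.
    for _ in range(5):
        x = torch.randn(16, 80)
        model(x).sum().backward()

    # Prints one line per (tensor name, dim, stats type), covering module
    # outputs, their grads, parameter values and parameter grads.
    diagnostic.print_diagnostics()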