From 4e23fb2252cafd94ad4c401abc310a5e56fc1f8b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 19 May 2022 11:45:59 +0800 Subject: [PATCH] Improve diagnostics code memory-wise and accumulate more stats. (#373) * Update diagnostics, hopefully print more stats. # Conflicts: # egs/librispeech/ASR/pruned_transducer_stateless4b/train.py * Remove memory-limit options arg * Remove unnecessary option for diagnostics code, collect on more batches --- .../ASR/pruned_transducer_stateless2/train.py | 7 +- .../ASR/pruned_transducer_stateless2/train.py | 7 +- .../ASR/pruned_transducer_stateless3/train.py | 7 +- .../ASR/pruned_transducer_stateless4/train.py | 7 +- .../ASR/transducer_stateless/train.py | 7 +- .../ASR/transducer_stateless2/train.py | 7 +- .../ASR/pruned_transducer_stateless2/train.py | 7 +- icefall/diagnostics.py | 296 +++++++----------- 8 files changed, 135 insertions(+), 210 deletions(-) diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py index 4421ce2aa..83ae25561 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py @@ -689,7 +689,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() - if params.print_diagnostics and batch_idx == 5: + if params.print_diagnostics and batch_idx == 30: return if ( @@ -831,10 +831,7 @@ def run(rank, world_size, args): scheduler.load_state_dict(checkpoints["scheduler"]) if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) + diagnostic = diagnostics.attach_diagnostics(model) gigaspeech = GigaSpeechAsrDataModule(args) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 51c1a231a..eed2df755 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -695,7 +695,7 @@ def train_one_epoch( display_and_save_batch(batch, params=params, sp=sp) raise - if params.print_diagnostics and batch_idx == 5: + if params.print_diagnostics and batch_idx == 30: return if ( @@ -839,10 +839,7 @@ def run(rank, world_size, args): scheduler.load_state_dict(checkpoints["scheduler"]) if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) + diagnostic = diagnostics.attach_diagnostics(model) librispeech = LibriSpeechAsrDataModule(args) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index 037f99bc7..f5a25a226 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -767,7 +767,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() - if params.print_diagnostics and batch_idx == 5: + if params.print_diagnostics and batch_idx == 30: return if ( @@ -938,10 +938,7 @@ def run(rank, world_size, args): scheduler.load_state_dict(checkpoints["scheduler"]) if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) + diagnostic = diagnostics.attach_diagnostics(model) librispeech = LibriSpeech(manifest_dir=args.manifest_dir) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 4ff69d521..ca7207122 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -724,7 +724,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() - if params.print_diagnostics and batch_idx == 5: + if params.print_diagnostics and batch_idx == 30: return if ( @@ -888,10 +888,7 @@ def run(rank, world_size, args): scheduler.load_state_dict(checkpoints["scheduler"]) if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) + diagnostic = diagnostics.attach_diagnostics(model) librispeech = LibriSpeechAsrDataModule(args) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 89f754b20..cb7f08a09 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -523,7 +523,7 @@ def train_one_epoch( loss.backward() clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() - if params.print_diagnostics and batch_idx == 5: + if params.print_diagnostics and batch_idx == 30: return if batch_idx % params.log_interval == 0: @@ -635,10 +635,7 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) + diagnostic = diagnostics.attach_diagnostics(model) train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: diff --git a/egs/librispeech/ASR/transducer_stateless2/train.py b/egs/librispeech/ASR/transducer_stateless2/train.py index 8ceffb489..cb13e317c 100755 --- a/egs/librispeech/ASR/transducer_stateless2/train.py +++ b/egs/librispeech/ASR/transducer_stateless2/train.py @@ -511,7 +511,7 @@ def train_one_epoch( loss.backward() clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() - if params.print_diagnostics and batch_idx == 5: + if params.print_diagnostics and batch_idx == 30: return if batch_idx % params.log_interval == 0: @@ -623,10 +623,7 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) + diagnostic = diagnostics.attach_diagnostics(model) train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py index 6c66bfb62..dda29b3e5 100755 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py @@ -690,7 +690,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() - if params.print_diagnostics and batch_idx == 5: + if params.print_diagnostics and batch_idx == 30: return if ( @@ -832,10 +832,7 @@ def run(rank, world_size, args): scheduler.load_state_dict(checkpoints["scheduler"]) if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) + diagnostic = diagnostics.attach_diagnostics(model) spgispeech = SPGISpeechAsrDataModule(args) diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index bc8fe3069..1cd685d37 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -19,7 +19,7 @@ import random from typing import List, Optional, Tuple - +from dataclasses import dataclass import torch from torch import Tensor, nn @@ -28,16 +28,12 @@ class TensorDiagnosticOptions(object): """Options object for tensor diagnostics: Args: - memory_limit: - The maximum number of bytes per tensor - (limits how many copies of the tensor we cache). max_eig_dim: The maximum dimension for which we print out eigenvalues (limited for speed reasons). """ - def __init__(self, memory_limit: int = (2 ** 20), max_eig_dim: int = 512): - self.memory_limit = memory_limit + def __init__(self, max_eig_dim: int = 512): self.max_eig_dim = max_eig_dim def dim_is_summarized(self, size: int): @@ -94,138 +90,12 @@ def get_tensor_stats( return x, count -def get_diagnostics_for_dim( - dim: int, - tensors: List[Tensor], - options: TensorDiagnosticOptions, - sizes_same: bool, - stats_type: str, -) -> str: - """ - This function gets diagnostics for a dimension of a module. - - Args: - dim: - the dimension to analyze, with 0 <= dim < tensors[0].ndim - options: - options object - sizes_same: - True if all the tensor sizes are the same on this dimension - stats_type: either "abs" or "positive" or "eigs" or "value", - imdictates the type of stats we accumulate, abs is mean absolute - value, "positive" is proportion of positive to nonnegative values, - "eigs" is eigenvalues after doing outer product on this dim, sum - over all other dimes. - Returns: - Diagnostic as a string, either percentiles or the actual values, - see the code. Will return the empty string if the diagnostics did - not make sense to print out for this dimension, e.g. dimension - mismatch and stats_type == "eigs". - """ - - # stats_and_counts is a list of pair (Tensor, int) - stats_and_counts = [get_tensor_stats(x, dim, stats_type) for x in tensors] - stats = [x[0] for x in stats_and_counts] - counts = [x[1] for x in stats_and_counts] - - if stats_type == "eigs": - try: - stats = torch.stack(stats).sum(dim=0) - except: # noqa - return "" - count = sum(counts) - stats = stats / count - try: - eigs, _ = torch.symeig(stats) - stats = eigs.abs().sqrt() - except: # noqa - print("Error getting eigenvalues, trying another method.") - eigs = torch.linalg.eigvals(stats) - stats = eigs.abs().sqrt() - # sqrt so it reflects data magnitude, like stddev- not variance - elif sizes_same: - stats = torch.stack(stats).sum(dim=0) - count = sum(counts) - stats = stats / count - else: - stats = [x[0] / x[1] for x in stats_and_counts] - stats = torch.cat(stats, dim=0) - if stats_type == "rms": - stats = stats.sqrt() - - # if `summarize` we print percentiles of the stats; else, - # we print out individual elements. - summarize = (not sizes_same) or options.dim_is_summarized(stats.numel()) - if summarize: - # print out percentiles. - stats = stats.sort()[0] - num_percentiles = 10 - size = stats.numel() - percentiles = [] - for i in range(num_percentiles + 1): - index = (i * (size - 1)) // num_percentiles - percentiles.append(stats[index].item()) - percentiles = ["%.2g" % x for x in percentiles] - percentiles = " ".join(percentiles) - ans = f"percentiles: [{percentiles}]" - else: - ans = stats.tolist() - ans = ["%.2g" % x for x in ans] - ans = "[" + " ".join(ans) + "]" - if stats_type == "value": - # This norm is useful because it is strictly less than the largest - # sqrt(eigenvalue) of the variance, which we print out, and shows, - # speaking in an approximate way, how much of that largest eigenvalue - # can be attributed to the mean of the distribution. - norm = (stats ** 2).sum().sqrt().item() - mean = stats.mean().item() - rms = (stats ** 2).mean().sqrt().item() - ans += f", norm={norm:.2g}, mean={mean:.2g}, rms={rms:.2g}" - else: - mean = stats.mean().item() - rms = (stats ** 2).mean().sqrt().item() - ans += f", mean={mean:.2g}, rms={rms:.2g}" - return ans -def print_diagnostics_for_dim( - name: str, dim: int, tensors: List[Tensor], options: TensorDiagnosticOptions -): - """This function prints diagnostics for a dimension of a tensor. - - Args: - name: - The tensor name. - dim: - The dimension to analyze, with 0 <= dim < tensors[0].ndim. - tensors: - List of cached tensors to get the stats. - options: - Options object. - """ - - ndim = tensors[0].ndim - if ndim > 1: - stats_types = ["abs", "positive", "value", "rms"] - if tensors[0].shape[dim] <= options.max_eig_dim: - stats_types.append("eigs") - else: - stats_types = ["value", "abs"] - - for stats_type in stats_types: - sizes = [x.shape[dim] for x in tensors] - sizes_same = all([x == sizes[0] for x in sizes]) - s = get_diagnostics_for_dim( - dim, tensors, options, sizes_same, stats_type - ) - if s == "": - continue - - min_size = min(sizes) - max_size = max(sizes) - size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}" - # stats_type will be "abs" or "positive". - print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}") +@dataclass +class TensorAndCount: + tensor: Tensor + count: int class TensorDiagnostic(object): @@ -238,12 +108,23 @@ class TensorDiagnostic(object): name: The tensor name. """ - def __init__(self, opts: TensorDiagnosticOptions, name: str): self.name = name self.opts = opts - # A list to cache the tensors. - self.saved_tensors = [] + + + self.stats = None # we'll later assign a list to this data member. It's a list of dict. + + # the keys into self.stats[dim] are strings, whose values can be + # "abs", "value", "positive", "rms", "value". + # The values e.g. self.stats[dim]["rms"] are lists of dataclass TensorAndCount, + # containing a tensor and its associated count (which is the sum of the other dims + # that we aggregated over, e.g. the number of frames and/or batch elements and/or + # channels. + # ... we actually accumulate the Tensors / counts any time we have the same-dim tensor, + # only adding a new element to the list if there was a different dim. + # if the string in the key is "eigs", if we detect a length mismatch we put None as the value. + def accumulate(self, x): """Accumulate tensors.""" @@ -251,50 +132,115 @@ class TensorDiagnostic(object): x = x[0] if not isinstance(x, Tensor): return - if x.device == torch.device("cpu"): - x = x.detach().clone() - else: - x = x.detach().to("cpu", non_blocking=True) - self.saved_tensors.append(x) - num = len(self.saved_tensors) - if num & (num - 1) == 0: # power of 2.. - self._limit_memory() + x = x.detach().clone() + if x.ndim == 0: + x = x.unsqueeze(0) + ndim = x.ndim + if self.stats is None: + self.stats = [ dict() for _ in range(ndim) ] - def _limit_memory(self): - """Only keep the newly cached tensors to limit memory.""" - if len(self.saved_tensors) > 1024: - self.saved_tensors = self.saved_tensors[-1024:] - return + for dim in range(ndim): + this_dim_stats = self.stats[dim] + if ndim > 1: + stats_types = ["abs", "positive", "value", "rms"] + if x.shape[dim] <= self.opts.max_eig_dim: + stats_types.append("eigs") + else: + stats_types = ["value", "abs"] + this_dict = self.stats[dim] + for stats_type in stats_types: + stats, count = get_tensor_stats(x, dim, stats_type) + if not stats_type in this_dim_stats: + this_dim_stats[stats_type] = [] # list of TensorAndCount + + done = False + if this_dim_stats[stats_type] is None: + # we can reach here if we detected for stats_type "eigs" that + # where was more than one different size for this dim. Then we + # disable accumulating this stats type, as it uses too much memory. + continue + for s in this_dim_stats[stats_type]: + if s.tensor.shape == stats.shape: + s.tensor += stats + s.count += count + done = True + break + if not done: + if this_dim_stats[stats_type] != [] and stats_type == "eigs": + # >1 size encountered on this dim, e.g. it's a batch or time dimension, + # don't accumulat "eigs" stats type, it uses too much memory + this_dim_stats[stats_type] = None + else: + this_dim_stats[stats_type].append(TensorAndCount(stats, count)) - tot_mem = 0.0 - for i in reversed(range(len(self.saved_tensors))): - tot_mem += ( - self.saved_tensors[i].numel() - * self.saved_tensors[i].element_size() - ) - if tot_mem > self.opts.memory_limit: - self.saved_tensors = self.saved_tensors[i:] - return def print_diagnostics(self): """Print diagnostics for each dimension of the tensor.""" - if len(self.saved_tensors) == 0: - print("{name}: no stats".format(name=self.name)) - return + for dim, this_dim_stats in enumerate(self.stats): + for stats_type, stats_list in this_dim_stats.items(): + # stats_type could be "rms", "value", "abs", "eigs", "positive". + # "value" could be a list of TensorAndCount, or None + if stats_list is None: + assert stats_type == "eigs" + continue - if self.saved_tensors[0].ndim == 0: - # Ensure there is at least one dim. - self.saved_tensors = [x.unsqueeze(0) for x in self.saved_tensors] + if stats_type == "eigs": + assert len(stats_list) == 1 + stats = stats_list[0].tensor / stats_list[0].count + try: + eigs, _ = torch.symeig(stats) + stats = eigs.abs().sqrt() + except: # noqa + print("Error getting eigenvalues, trying another method.") + eigs = torch.linalg.eigvals(stats) + stats = eigs.abs().sqrt() + # sqrt so it reflects data magnitude, like stddev- not variance + elif len(stats_list) == 1: + stats = stats_list[0].tensor / stats_list[0].count + else: + stats = torch.cat([x.tensor / x.count for x in stats_list], dim=0) - try: - device = torch.device("cuda") - except: # noqa - device = torch.device("cpu") + if stats_type == "rms": + # we stored the square; after aggregation we need to take sqrt. + stats = stats.sqrt() + + # if `summarize` we print percentiles of the stats; else, + # we print out individual elements. + summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(stats.numel()) + if summarize: # usually `summarize` will be true + # print out percentiles. + stats = stats.sort()[0] + num_percentiles = 10 + size = stats.numel() + percentiles = [] + for i in range(num_percentiles + 1): + index = (i * (size - 1)) // num_percentiles + percentiles.append(stats[index].item()) + percentiles = ["%.2g" % x for x in percentiles] + percentiles = " ".join(percentiles) + ans = f"percentiles: [{percentiles}]" + else: + ans = stats.tolist() + ans = ["%.2g" % x for x in ans] + ans = "[" + " ".join(ans) + "]" + if stats_type == "value": + # This norm is useful because it is strictly less than the largest + # sqrt(eigenvalue) of the variance, which we print out, and shows, + # speaking in an approximate way, how much of that largest eigenvalue + # can be attributed to the mean of the distribution. + norm = (stats ** 2).sum().sqrt().item() + ans += f", norm={norm:.2g}" + mean = stats.mean().item() + rms = (stats ** 2).mean().sqrt().item() + ans += f", mean={mean:.2g}, rms={rms:.2g}" + + # OK, "ans" contains the actual stats, e.g. + # ans = "percentiles: [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], mean=0.5, rms=0.5" + + sizes = [x.tensor.shape[0] for x in stats_list] + size_str = f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}" + print(f"module={self.name}, dim={dim}, size={size_str}, {stats_type} {ans}") - ndim = self.saved_tensors[0].ndim - tensors = [x.to(device) for x in self.saved_tensors] - for dim in range(ndim): - print_diagnostics_for_dim(self.name, dim, tensors, self.opts) class ModelDiagnostic(object):