Improve diagnostics code memory-wise and accumulate more stats. (#373)

* Update diagnostics, hopefully print more stats.

# Conflicts:
#	egs/librispeech/ASR/pruned_transducer_stateless4b/train.py

* Remove memory-limit options arg

* Remove unnecessary option for diagnostics code, collect on more batches
This commit is contained in:
Daniel Povey 2022-05-19 11:45:59 +08:00 committed by GitHub
parent f6ce135608
commit 4e23fb2252
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 135 additions and 210 deletions

View File

@ -689,7 +689,7 @@ def train_one_epoch(
scaler.update() scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
if params.print_diagnostics and batch_idx == 5: if params.print_diagnostics and batch_idx == 30:
return return
if ( if (
@ -831,10 +831,7 @@ def run(rank, world_size, args):
scheduler.load_state_dict(checkpoints["scheduler"]) scheduler.load_state_dict(checkpoints["scheduler"])
if params.print_diagnostics: if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions( diagnostic = diagnostics.attach_diagnostics(model)
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
gigaspeech = GigaSpeechAsrDataModule(args) gigaspeech = GigaSpeechAsrDataModule(args)

View File

@ -695,7 +695,7 @@ def train_one_epoch(
display_and_save_batch(batch, params=params, sp=sp) display_and_save_batch(batch, params=params, sp=sp)
raise raise
if params.print_diagnostics and batch_idx == 5: if params.print_diagnostics and batch_idx == 30:
return return
if ( if (
@ -839,10 +839,7 @@ def run(rank, world_size, args):
scheduler.load_state_dict(checkpoints["scheduler"]) scheduler.load_state_dict(checkpoints["scheduler"])
if params.print_diagnostics: if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions( diagnostic = diagnostics.attach_diagnostics(model)
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)

View File

@ -767,7 +767,7 @@ def train_one_epoch(
scaler.update() scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
if params.print_diagnostics and batch_idx == 5: if params.print_diagnostics and batch_idx == 30:
return return
if ( if (
@ -938,10 +938,7 @@ def run(rank, world_size, args):
scheduler.load_state_dict(checkpoints["scheduler"]) scheduler.load_state_dict(checkpoints["scheduler"])
if params.print_diagnostics: if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions( diagnostic = diagnostics.attach_diagnostics(model)
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
librispeech = LibriSpeech(manifest_dir=args.manifest_dir) librispeech = LibriSpeech(manifest_dir=args.manifest_dir)

View File

@ -724,7 +724,7 @@ def train_one_epoch(
scaler.update() scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
if params.print_diagnostics and batch_idx == 5: if params.print_diagnostics and batch_idx == 30:
return return
if ( if (
@ -888,10 +888,7 @@ def run(rank, world_size, args):
scheduler.load_state_dict(checkpoints["scheduler"]) scheduler.load_state_dict(checkpoints["scheduler"])
if params.print_diagnostics: if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions( diagnostic = diagnostics.attach_diagnostics(model)
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)

View File

@ -523,7 +523,7 @@ def train_one_epoch(
loss.backward() loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0) clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step() optimizer.step()
if params.print_diagnostics and batch_idx == 5: if params.print_diagnostics and batch_idx == 30:
return return
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
@ -635,10 +635,7 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
if params.print_diagnostics: if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions( diagnostic = diagnostics.attach_diagnostics(model)
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
train_cuts = librispeech.train_clean_100_cuts() train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri: if params.full_libri:

View File

@ -511,7 +511,7 @@ def train_one_epoch(
loss.backward() loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0) clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step() optimizer.step()
if params.print_diagnostics and batch_idx == 5: if params.print_diagnostics and batch_idx == 30:
return return
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
@ -623,10 +623,7 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
if params.print_diagnostics: if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions( diagnostic = diagnostics.attach_diagnostics(model)
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
train_cuts = librispeech.train_clean_100_cuts() train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri: if params.full_libri:

View File

@ -690,7 +690,7 @@ def train_one_epoch(
scaler.update() scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
if params.print_diagnostics and batch_idx == 5: if params.print_diagnostics and batch_idx == 30:
return return
if ( if (
@ -832,10 +832,7 @@ def run(rank, world_size, args):
scheduler.load_state_dict(checkpoints["scheduler"]) scheduler.load_state_dict(checkpoints["scheduler"])
if params.print_diagnostics: if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions( diagnostic = diagnostics.attach_diagnostics(model)
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
spgispeech = SPGISpeechAsrDataModule(args) spgispeech = SPGISpeechAsrDataModule(args)

View File

@ -19,7 +19,7 @@
import random import random
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from dataclasses import dataclass
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
@ -28,16 +28,12 @@ class TensorDiagnosticOptions(object):
"""Options object for tensor diagnostics: """Options object for tensor diagnostics:
Args: Args:
memory_limit:
The maximum number of bytes per tensor
(limits how many copies of the tensor we cache).
max_eig_dim: max_eig_dim:
The maximum dimension for which we print out eigenvalues The maximum dimension for which we print out eigenvalues
(limited for speed reasons). (limited for speed reasons).
""" """
def __init__(self, memory_limit: int = (2 ** 20), max_eig_dim: int = 512): def __init__(self, max_eig_dim: int = 512):
self.memory_limit = memory_limit
self.max_eig_dim = max_eig_dim self.max_eig_dim = max_eig_dim
def dim_is_summarized(self, size: int): def dim_is_summarized(self, size: int):
@ -94,138 +90,12 @@ def get_tensor_stats(
return x, count return x, count
def get_diagnostics_for_dim(
dim: int,
tensors: List[Tensor],
options: TensorDiagnosticOptions,
sizes_same: bool,
stats_type: str,
) -> str:
"""
This function gets diagnostics for a dimension of a module.
Args:
dim:
the dimension to analyze, with 0 <= dim < tensors[0].ndim
options:
options object
sizes_same:
True if all the tensor sizes are the same on this dimension
stats_type: either "abs" or "positive" or "eigs" or "value",
indicates the type of stats we accumulate, abs is mean absolute
value, "positive" is proportion of positive to nonnegative values,
"eigs" is eigenvalues after doing outer product on this dim, sum
over all other dims.
Returns:
Diagnostic as a string, either percentiles or the actual values,
see the code. Will return the empty string if the diagnostics did
not make sense to print out for this dimension, e.g. dimension
mismatch and stats_type == "eigs".
"""
# stats_and_counts is a list of pair (Tensor, int)
stats_and_counts = [get_tensor_stats(x, dim, stats_type) for x in tensors]
stats = [x[0] for x in stats_and_counts]
counts = [x[1] for x in stats_and_counts]
if stats_type == "eigs":
try:
stats = torch.stack(stats).sum(dim=0)
except: # noqa
return ""
count = sum(counts)
stats = stats / count
try:
eigs, _ = torch.symeig(stats)
stats = eigs.abs().sqrt()
except: # noqa
print("Error getting eigenvalues, trying another method.")
eigs = torch.linalg.eigvals(stats)
stats = eigs.abs().sqrt()
# sqrt so it reflects data magnitude, like stddev- not variance
elif sizes_same:
stats = torch.stack(stats).sum(dim=0)
count = sum(counts)
stats = stats / count
else:
stats = [x[0] / x[1] for x in stats_and_counts]
stats = torch.cat(stats, dim=0)
if stats_type == "rms":
stats = stats.sqrt()
# if `summarize` we print percentiles of the stats; else,
# we print out individual elements.
summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
if summarize:
# print out percentiles.
stats = stats.sort()[0]
num_percentiles = 10
size = stats.numel()
percentiles = []
for i in range(num_percentiles + 1):
index = (i * (size - 1)) // num_percentiles
percentiles.append(stats[index].item())
percentiles = ["%.2g" % x for x in percentiles]
percentiles = " ".join(percentiles)
ans = f"percentiles: [{percentiles}]"
else:
ans = stats.tolist()
ans = ["%.2g" % x for x in ans]
ans = "[" + " ".join(ans) + "]"
if stats_type == "value":
# This norm is useful because it is strictly less than the largest
# sqrt(eigenvalue) of the variance, which we print out, and shows,
# speaking in an approximate way, how much of that largest eigenvalue
# can be attributed to the mean of the distribution.
norm = (stats ** 2).sum().sqrt().item()
mean = stats.mean().item()
rms = (stats ** 2).mean().sqrt().item()
ans += f", norm={norm:.2g}, mean={mean:.2g}, rms={rms:.2g}"
else:
mean = stats.mean().item()
rms = (stats ** 2).mean().sqrt().item()
ans += f", mean={mean:.2g}, rms={rms:.2g}"
return ans
def print_diagnostics_for_dim( @dataclass
name: str, dim: int, tensors: List[Tensor], options: TensorDiagnosticOptions class TensorAndCount:
): tensor: Tensor
"""This function prints diagnostics for a dimension of a tensor. count: int
Args:
name:
The tensor name.
dim:
The dimension to analyze, with 0 <= dim < tensors[0].ndim.
tensors:
List of cached tensors to get the stats.
options:
Options object.
"""
ndim = tensors[0].ndim
if ndim > 1:
stats_types = ["abs", "positive", "value", "rms"]
if tensors[0].shape[dim] <= options.max_eig_dim:
stats_types.append("eigs")
else:
stats_types = ["value", "abs"]
for stats_type in stats_types:
sizes = [x.shape[dim] for x in tensors]
sizes_same = all([x == sizes[0] for x in sizes])
s = get_diagnostics_for_dim(
dim, tensors, options, sizes_same, stats_type
)
if s == "":
continue
min_size = min(sizes)
max_size = max(sizes)
size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}"
# stats_type will be "abs" or "positive".
print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}")
class TensorDiagnostic(object): class TensorDiagnostic(object):
@ -238,12 +108,23 @@ class TensorDiagnostic(object):
name: name:
The tensor name. The tensor name.
""" """
def __init__(self, opts: TensorDiagnosticOptions, name: str): def __init__(self, opts: TensorDiagnosticOptions, name: str):
self.name = name self.name = name
self.opts = opts self.opts = opts
# A list to cache the tensors.
self.saved_tensors = []
self.stats = None # we'll later assign a list to this data member. It's a list of dict.
# The keys into self.stats[dim] are strings, whose values can be
# "abs", "value", "positive", "rms", "eigs".
# The values, e.g. self.stats[dim]["rms"], are lists of dataclass TensorAndCount,
# each containing a tensor and its associated count (which is the sum of the other dims
# that we aggregated over, e.g. the number of frames and/or batch elements and/or
# channels).
# ... we actually accumulate the Tensors / counts any time we have the same-dim tensor,
# only adding a new element to the list if there was a different dim.
# If the key is "eigs" and we detect a length mismatch, we put None as the value.
def accumulate(self, x): def accumulate(self, x):
"""Accumulate tensors.""" """Accumulate tensors."""
@ -251,50 +132,115 @@ class TensorDiagnostic(object):
x = x[0] x = x[0]
if not isinstance(x, Tensor): if not isinstance(x, Tensor):
return return
if x.device == torch.device("cpu"): x = x.detach().clone()
x = x.detach().clone() if x.ndim == 0:
else: x = x.unsqueeze(0)
x = x.detach().to("cpu", non_blocking=True) ndim = x.ndim
self.saved_tensors.append(x) if self.stats is None:
num = len(self.saved_tensors) self.stats = [ dict() for _ in range(ndim) ]
if num & (num - 1) == 0: # power of 2..
self._limit_memory()
def _limit_memory(self): for dim in range(ndim):
"""Only keep the newly cached tensors to limit memory.""" this_dim_stats = self.stats[dim]
if len(self.saved_tensors) > 1024: if ndim > 1:
self.saved_tensors = self.saved_tensors[-1024:] stats_types = ["abs", "positive", "value", "rms"]
return if x.shape[dim] <= self.opts.max_eig_dim:
stats_types.append("eigs")
else:
stats_types = ["value", "abs"]
this_dict = self.stats[dim]
for stats_type in stats_types:
stats, count = get_tensor_stats(x, dim, stats_type)
if not stats_type in this_dim_stats:
this_dim_stats[stats_type] = [] # list of TensorAndCount
done = False
if this_dim_stats[stats_type] is None:
# we can reach here if we detected for stats_type "eigs" that
# there was more than one different size for this dim. Then we
# disable accumulating this stats type, as it uses too much memory.
continue
for s in this_dim_stats[stats_type]:
if s.tensor.shape == stats.shape:
s.tensor += stats
s.count += count
done = True
break
if not done:
if this_dim_stats[stats_type] != [] and stats_type == "eigs":
# >1 size encountered on this dim, e.g. it's a batch or time dimension,
# don't accumulate "eigs" stats type, it uses too much memory
this_dim_stats[stats_type] = None
else:
this_dim_stats[stats_type].append(TensorAndCount(stats, count))
tot_mem = 0.0
for i in reversed(range(len(self.saved_tensors))):
tot_mem += (
self.saved_tensors[i].numel()
* self.saved_tensors[i].element_size()
)
if tot_mem > self.opts.memory_limit:
self.saved_tensors = self.saved_tensors[i:]
return
def print_diagnostics(self): def print_diagnostics(self):
"""Print diagnostics for each dimension of the tensor.""" """Print diagnostics for each dimension of the tensor."""
if len(self.saved_tensors) == 0: for dim, this_dim_stats in enumerate(self.stats):
print("{name}: no stats".format(name=self.name)) for stats_type, stats_list in this_dim_stats.items():
return # stats_type could be "rms", "value", "abs", "eigs", "positive".
# stats_list could be a list of TensorAndCount, or None
if stats_list is None:
assert stats_type == "eigs"
continue
if self.saved_tensors[0].ndim == 0: if stats_type == "eigs":
# Ensure there is at least one dim. assert len(stats_list) == 1
self.saved_tensors = [x.unsqueeze(0) for x in self.saved_tensors] stats = stats_list[0].tensor / stats_list[0].count
try:
eigs, _ = torch.symeig(stats)
stats = eigs.abs().sqrt()
except: # noqa
print("Error getting eigenvalues, trying another method.")
eigs = torch.linalg.eigvals(stats)
stats = eigs.abs().sqrt()
# sqrt so it reflects data magnitude, like stddev- not variance
elif len(stats_list) == 1:
stats = stats_list[0].tensor / stats_list[0].count
else:
stats = torch.cat([x.tensor / x.count for x in stats_list], dim=0)
try: if stats_type == "rms":
device = torch.device("cuda") # we stored the square; after aggregation we need to take sqrt.
except: # noqa stats = stats.sqrt()
device = torch.device("cpu")
# if `summarize` we print percentiles of the stats; else,
# we print out individual elements.
summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(stats.numel())
if summarize: # usually `summarize` will be true
# print out percentiles.
stats = stats.sort()[0]
num_percentiles = 10
size = stats.numel()
percentiles = []
for i in range(num_percentiles + 1):
index = (i * (size - 1)) // num_percentiles
percentiles.append(stats[index].item())
percentiles = ["%.2g" % x for x in percentiles]
percentiles = " ".join(percentiles)
ans = f"percentiles: [{percentiles}]"
else:
ans = stats.tolist()
ans = ["%.2g" % x for x in ans]
ans = "[" + " ".join(ans) + "]"
if stats_type == "value":
# This norm is useful because it is strictly less than the largest
# sqrt(eigenvalue) of the variance, which we print out, and shows,
# speaking in an approximate way, how much of that largest eigenvalue
# can be attributed to the mean of the distribution.
norm = (stats ** 2).sum().sqrt().item()
ans += f", norm={norm:.2g}"
mean = stats.mean().item()
rms = (stats ** 2).mean().sqrt().item()
ans += f", mean={mean:.2g}, rms={rms:.2g}"
# OK, "ans" contains the actual stats, e.g.
# ans = "percentiles: [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], mean=0.5, rms=0.5"
sizes = [x.tensor.shape[0] for x in stats_list]
size_str = f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
print(f"module={self.name}, dim={dim}, size={size_str}, {stats_type} {ans}")
ndim = self.saved_tensors[0].ndim
tensors = [x.to(device) for x in self.saved_tensors]
for dim in range(ndim):
print_diagnostics_for_dim(self.name, dim, tensors, self.opts)
class ModelDiagnostic(object): class ModelDiagnostic(object):