Improve diagnostics code memory-wise and accumulate more stats. (#373)

* Update diagnostics, hopefully print more stats.

# Conflicts:
#	egs/librispeech/ASR/pruned_transducer_stateless4b/train.py

* Remove memory-limit options arg

* Remove unnecessary option for diagnostics code, collect on more batches
This commit is contained in:
Daniel Povey 2022-05-19 11:45:59 +08:00 committed by GitHub
parent f6ce135608
commit 4e23fb2252
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 135 additions and 210 deletions

View File

@ -689,7 +689,7 @@ def train_one_epoch(
scaler.update()
optimizer.zero_grad()
if params.print_diagnostics and batch_idx == 5:
if params.print_diagnostics and batch_idx == 30:
return
if (
@ -831,10 +831,7 @@ def run(rank, world_size, args):
scheduler.load_state_dict(checkpoints["scheduler"])
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
diagnostic = diagnostics.attach_diagnostics(model)
gigaspeech = GigaSpeechAsrDataModule(args)

View File

@ -695,7 +695,7 @@ def train_one_epoch(
display_and_save_batch(batch, params=params, sp=sp)
raise
if params.print_diagnostics and batch_idx == 5:
if params.print_diagnostics and batch_idx == 30:
return
if (
@ -839,10 +839,7 @@ def run(rank, world_size, args):
scheduler.load_state_dict(checkpoints["scheduler"])
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
diagnostic = diagnostics.attach_diagnostics(model)
librispeech = LibriSpeechAsrDataModule(args)

View File

@ -767,7 +767,7 @@ def train_one_epoch(
scaler.update()
optimizer.zero_grad()
if params.print_diagnostics and batch_idx == 5:
if params.print_diagnostics and batch_idx == 30:
return
if (
@ -938,10 +938,7 @@ def run(rank, world_size, args):
scheduler.load_state_dict(checkpoints["scheduler"])
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
diagnostic = diagnostics.attach_diagnostics(model)
librispeech = LibriSpeech(manifest_dir=args.manifest_dir)

View File

@ -724,7 +724,7 @@ def train_one_epoch(
scaler.update()
optimizer.zero_grad()
if params.print_diagnostics and batch_idx == 5:
if params.print_diagnostics and batch_idx == 30:
return
if (
@ -888,10 +888,7 @@ def run(rank, world_size, args):
scheduler.load_state_dict(checkpoints["scheduler"])
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
diagnostic = diagnostics.attach_diagnostics(model)
librispeech = LibriSpeechAsrDataModule(args)

View File

@ -523,7 +523,7 @@ def train_one_epoch(
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if params.print_diagnostics and batch_idx == 5:
if params.print_diagnostics and batch_idx == 30:
return
if batch_idx % params.log_interval == 0:
@ -635,10 +635,7 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args)
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
diagnostic = diagnostics.attach_diagnostics(model)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:

View File

@ -511,7 +511,7 @@ def train_one_epoch(
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if params.print_diagnostics and batch_idx == 5:
if params.print_diagnostics and batch_idx == 30:
return
if batch_idx % params.log_interval == 0:
@ -623,10 +623,7 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args)
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
diagnostic = diagnostics.attach_diagnostics(model)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:

View File

@ -690,7 +690,7 @@ def train_one_epoch(
scaler.update()
optimizer.zero_grad()
if params.print_diagnostics and batch_idx == 5:
if params.print_diagnostics and batch_idx == 30:
return
if (
@ -832,10 +832,7 @@ def run(rank, world_size, args):
scheduler.load_state_dict(checkpoints["scheduler"])
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
diagnostic = diagnostics.attach_diagnostics(model)
spgispeech = SPGISpeechAsrDataModule(args)

View File

@ -19,7 +19,7 @@
import random
from typing import List, Optional, Tuple
from dataclasses import dataclass
import torch
from torch import Tensor, nn
@ -28,16 +28,12 @@ class TensorDiagnosticOptions(object):
"""Options object for tensor diagnostics:
Args:
memory_limit:
The maximum number of bytes per tensor
(limits how many copies of the tensor we cache).
max_eig_dim:
The maximum dimension for which we print out eigenvalues
(limited for speed reasons).
"""
def __init__(self, memory_limit: int = (2 ** 20), max_eig_dim: int = 512):
self.memory_limit = memory_limit
def __init__(self, max_eig_dim: int = 512):
self.max_eig_dim = max_eig_dim
def dim_is_summarized(self, size: int):
@ -94,138 +90,12 @@ def get_tensor_stats(
return x, count
def get_diagnostics_for_dim(
dim: int,
tensors: List[Tensor],
options: TensorDiagnosticOptions,
sizes_same: bool,
stats_type: str,
) -> str:
"""
This function gets diagnostics for a dimension of a module.
Args:
dim:
the dimension to analyze, with 0 <= dim < tensors[0].ndim
options:
options object
sizes_same:
True if all the tensor sizes are the same on this dimension
stats_type: either "abs" or "positive" or "eigs" or "value",
imdictates the type of stats we accumulate, abs is mean absolute
value, "positive" is proportion of positive to nonnegative values,
"eigs" is eigenvalues after doing outer product on this dim, sum
over all other dimes.
Returns:
Diagnostic as a string, either percentiles or the actual values,
see the code. Will return the empty string if the diagnostics did
not make sense to print out for this dimension, e.g. dimension
mismatch and stats_type == "eigs".
"""
# stats_and_counts is a list of pair (Tensor, int)
stats_and_counts = [get_tensor_stats(x, dim, stats_type) for x in tensors]
stats = [x[0] for x in stats_and_counts]
counts = [x[1] for x in stats_and_counts]
if stats_type == "eigs":
try:
stats = torch.stack(stats).sum(dim=0)
except: # noqa
return ""
count = sum(counts)
stats = stats / count
try:
eigs, _ = torch.symeig(stats)
stats = eigs.abs().sqrt()
except: # noqa
print("Error getting eigenvalues, trying another method.")
eigs = torch.linalg.eigvals(stats)
stats = eigs.abs().sqrt()
# sqrt so it reflects data magnitude, like stddev- not variance
elif sizes_same:
stats = torch.stack(stats).sum(dim=0)
count = sum(counts)
stats = stats / count
else:
stats = [x[0] / x[1] for x in stats_and_counts]
stats = torch.cat(stats, dim=0)
if stats_type == "rms":
stats = stats.sqrt()
# if `summarize` we print percentiles of the stats; else,
# we print out individual elements.
summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
if summarize:
# print out percentiles.
stats = stats.sort()[0]
num_percentiles = 10
size = stats.numel()
percentiles = []
for i in range(num_percentiles + 1):
index = (i * (size - 1)) // num_percentiles
percentiles.append(stats[index].item())
percentiles = ["%.2g" % x for x in percentiles]
percentiles = " ".join(percentiles)
ans = f"percentiles: [{percentiles}]"
else:
ans = stats.tolist()
ans = ["%.2g" % x for x in ans]
ans = "[" + " ".join(ans) + "]"
if stats_type == "value":
# This norm is useful because it is strictly less than the largest
# sqrt(eigenvalue) of the variance, which we print out, and shows,
# speaking in an approximate way, how much of that largest eigenvalue
# can be attributed to the mean of the distribution.
norm = (stats ** 2).sum().sqrt().item()
mean = stats.mean().item()
rms = (stats ** 2).mean().sqrt().item()
ans += f", norm={norm:.2g}, mean={mean:.2g}, rms={rms:.2g}"
else:
mean = stats.mean().item()
rms = (stats ** 2).mean().sqrt().item()
ans += f", mean={mean:.2g}, rms={rms:.2g}"
return ans
def print_diagnostics_for_dim(
name: str, dim: int, tensors: List[Tensor], options: TensorDiagnosticOptions
):
"""This function prints diagnostics for a dimension of a tensor.
Args:
name:
The tensor name.
dim:
The dimension to analyze, with 0 <= dim < tensors[0].ndim.
tensors:
List of cached tensors to get the stats.
options:
Options object.
"""
ndim = tensors[0].ndim
if ndim > 1:
stats_types = ["abs", "positive", "value", "rms"]
if tensors[0].shape[dim] <= options.max_eig_dim:
stats_types.append("eigs")
else:
stats_types = ["value", "abs"]
for stats_type in stats_types:
sizes = [x.shape[dim] for x in tensors]
sizes_same = all([x == sizes[0] for x in sizes])
s = get_diagnostics_for_dim(
dim, tensors, options, sizes_same, stats_type
)
if s == "":
continue
min_size = min(sizes)
max_size = max(sizes)
size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}"
# stats_type will be "abs" or "positive".
print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}")
@dataclass
class TensorAndCount:
tensor: Tensor
count: int
class TensorDiagnostic(object):
@ -238,12 +108,23 @@ class TensorDiagnostic(object):
name:
The tensor name.
"""
def __init__(self, opts: TensorDiagnosticOptions, name: str):
self.name = name
self.opts = opts
# A list to cache the tensors.
self.saved_tensors = []
self.stats = None # we'll later assign a list to this data member. It's a list of dict.
# the keys into self.stats[dim] are strings, whose values can be
# "abs", "value", "positive", "rms", "value".
# The values e.g. self.stats[dim]["rms"] are lists of dataclass TensorAndCount,
# containing a tensor and its associated count (which is the sum of the other dims
# that we aggregated over, e.g. the number of frames and/or batch elements and/or
# channels.
# ... we actually accumulate the Tensors / counts any time we have the same-dim tensor,
# only adding a new element to the list if there was a different dim.
# if the string in the key is "eigs", if we detect a length mismatch we put None as the value.
def accumulate(self, x):
"""Accumulate tensors."""
@ -251,50 +132,115 @@ class TensorDiagnostic(object):
x = x[0]
if not isinstance(x, Tensor):
return
if x.device == torch.device("cpu"):
x = x.detach().clone()
else:
x = x.detach().to("cpu", non_blocking=True)
self.saved_tensors.append(x)
num = len(self.saved_tensors)
if num & (num - 1) == 0: # power of 2..
self._limit_memory()
x = x.detach().clone()
if x.ndim == 0:
x = x.unsqueeze(0)
ndim = x.ndim
if self.stats is None:
self.stats = [ dict() for _ in range(ndim) ]
def _limit_memory(self):
"""Only keep the newly cached tensors to limit memory."""
if len(self.saved_tensors) > 1024:
self.saved_tensors = self.saved_tensors[-1024:]
return
for dim in range(ndim):
this_dim_stats = self.stats[dim]
if ndim > 1:
stats_types = ["abs", "positive", "value", "rms"]
if x.shape[dim] <= self.opts.max_eig_dim:
stats_types.append("eigs")
else:
stats_types = ["value", "abs"]
this_dict = self.stats[dim]
for stats_type in stats_types:
stats, count = get_tensor_stats(x, dim, stats_type)
if not stats_type in this_dim_stats:
this_dim_stats[stats_type] = [] # list of TensorAndCount
done = False
if this_dim_stats[stats_type] is None:
# we can reach here if we detected for stats_type "eigs" that
# where was more than one different size for this dim. Then we
# disable accumulating this stats type, as it uses too much memory.
continue
for s in this_dim_stats[stats_type]:
if s.tensor.shape == stats.shape:
s.tensor += stats
s.count += count
done = True
break
if not done:
if this_dim_stats[stats_type] != [] and stats_type == "eigs":
# >1 size encountered on this dim, e.g. it's a batch or time dimension,
# don't accumulat "eigs" stats type, it uses too much memory
this_dim_stats[stats_type] = None
else:
this_dim_stats[stats_type].append(TensorAndCount(stats, count))
tot_mem = 0.0
for i in reversed(range(len(self.saved_tensors))):
tot_mem += (
self.saved_tensors[i].numel()
* self.saved_tensors[i].element_size()
)
if tot_mem > self.opts.memory_limit:
self.saved_tensors = self.saved_tensors[i:]
return
def print_diagnostics(self):
"""Print diagnostics for each dimension of the tensor."""
if len(self.saved_tensors) == 0:
print("{name}: no stats".format(name=self.name))
return
for dim, this_dim_stats in enumerate(self.stats):
for stats_type, stats_list in this_dim_stats.items():
# stats_type could be "rms", "value", "abs", "eigs", "positive".
# "value" could be a list of TensorAndCount, or None
if stats_list is None:
assert stats_type == "eigs"
continue
if self.saved_tensors[0].ndim == 0:
# Ensure there is at least one dim.
self.saved_tensors = [x.unsqueeze(0) for x in self.saved_tensors]
if stats_type == "eigs":
assert len(stats_list) == 1
stats = stats_list[0].tensor / stats_list[0].count
try:
eigs, _ = torch.symeig(stats)
stats = eigs.abs().sqrt()
except: # noqa
print("Error getting eigenvalues, trying another method.")
eigs = torch.linalg.eigvals(stats)
stats = eigs.abs().sqrt()
# sqrt so it reflects data magnitude, like stddev- not variance
elif len(stats_list) == 1:
stats = stats_list[0].tensor / stats_list[0].count
else:
stats = torch.cat([x.tensor / x.count for x in stats_list], dim=0)
try:
device = torch.device("cuda")
except: # noqa
device = torch.device("cpu")
if stats_type == "rms":
# we stored the square; after aggregation we need to take sqrt.
stats = stats.sqrt()
# if `summarize` we print percentiles of the stats; else,
# we print out individual elements.
summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(stats.numel())
if summarize: # usually `summarize` will be true
# print out percentiles.
stats = stats.sort()[0]
num_percentiles = 10
size = stats.numel()
percentiles = []
for i in range(num_percentiles + 1):
index = (i * (size - 1)) // num_percentiles
percentiles.append(stats[index].item())
percentiles = ["%.2g" % x for x in percentiles]
percentiles = " ".join(percentiles)
ans = f"percentiles: [{percentiles}]"
else:
ans = stats.tolist()
ans = ["%.2g" % x for x in ans]
ans = "[" + " ".join(ans) + "]"
if stats_type == "value":
# This norm is useful because it is strictly less than the largest
# sqrt(eigenvalue) of the variance, which we print out, and shows,
# speaking in an approximate way, how much of that largest eigenvalue
# can be attributed to the mean of the distribution.
norm = (stats ** 2).sum().sqrt().item()
ans += f", norm={norm:.2g}"
mean = stats.mean().item()
rms = (stats ** 2).mean().sqrt().item()
ans += f", mean={mean:.2g}, rms={rms:.2g}"
# OK, "ans" contains the actual stats, e.g.
# ans = "percentiles: [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], mean=0.5, rms=0.5"
sizes = [x.tensor.shape[0] for x in stats_list]
size_str = f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
print(f"module={self.name}, dim={dim}, size={size_str}, {stats_type} {ans}")
ndim = self.saved_tensors[0].ndim
tensors = [x.to(device) for x in self.saved_tensors]
for dim in range(ndim):
print_diagnostics_for_dim(self.name, dim, tensors, self.opts)
class ModelDiagnostic(object):