Update diagnostics, hopefully print more stats.

# Conflicts:
#	egs/librispeech/ASR/pruned_transducer_stateless4b/train.py
This commit is contained in:
Daniel Povey 2022-05-18 21:42:57 +08:00
parent f6ce135608
commit c2c46ea023

View File

@ -19,7 +19,7 @@
import random
from typing import List, Optional, Tuple
from dataclasses import dataclass
import torch
from torch import Tensor, nn
@ -94,47 +94,103 @@ def get_tensor_stats(
return x, count
def get_diagnostics_for_dim(
dim: int,
tensors: List[Tensor],
options: TensorDiagnosticOptions,
sizes_same: bool,
stats_type: str,
) -> str:
"""
This function gets diagnostics for a dimension of a module.
@dataclass
class TensorAndCount:
tensor: Tensor
count: int
class TensorDiagnostic(object):
"""This class is not directly used by the user, it is responsible for
collecting diagnostics for a single parameter tensor of a torch.nn.Module.
Args:
dim:
the dimension to analyze, with 0 <= dim < tensors[0].ndim
options:
options object
sizes_same:
True if all the tensor sizes are the same on this dimension
stats_type: either "abs" or "positive" or "eigs" or "value",
imdictates the type of stats we accumulate, abs is mean absolute
value, "positive" is proportion of positive to nonnegative values,
"eigs" is eigenvalues after doing outer product on this dim, sum
over all other dimes.
Returns:
Diagnostic as a string, either percentiles or the actual values,
see the code. Will return the empty string if the diagnostics did
not make sense to print out for this dimension, e.g. dimension
mismatch and stats_type == "eigs".
opts:
Options object.
name:
The tensor name.
"""
def __init__(self, opts: TensorDiagnosticOptions, name: str):
self.name = name
self.opts = opts
# stats_and_counts is a list of pair (Tensor, int)
stats_and_counts = [get_tensor_stats(x, dim, stats_type) for x in tensors]
stats = [x[0] for x in stats_and_counts]
counts = [x[1] for x in stats_and_counts]
self.stats = None # we'll later assign a list to this data member. It's a list of dict.
# the keys into self.stats[dim] are strings, whose values can be
# "abs", "value", "positive", "rms", "value".
# The values e.g. self.stats[dim]["rms"] are lists of dataclass TensorAndCount,
# containing a tensor and its associated count (which is the sum of the other dims
# that we aggregated over, e.g. the number of frames and/or batch elements and/or
# channels.
# ... we actually accumulate the Tensors / counts any time we have the same-dim tensor,
# only adding a new element to the list if there was a different dim.
# if the string in the key is "eigs", if we detect a length mismatch we put None as the value.
def accumulate(self, x):
"""Accumulate tensors."""
if isinstance(x, Tuple):
x = x[0]
if not isinstance(x, Tensor):
return
x = x.detach().clone()
if x.ndim == 0:
x = x.unsqueeze(0)
ndim = x.ndim
if self.stats is None:
self.stats = [ dict() for _ in range(ndim) ]
for dim in range(ndim):
this_dim_stats = self.stats[dim]
if ndim > 1:
stats_types = ["abs", "positive", "value", "rms"]
if x.shape[dim] <= self.opts.max_eig_dim:
stats_types.append("eigs")
else:
stats_types = ["value", "abs"]
this_dict = self.stats[dim]
for stats_type in stats_types:
stats, count = get_tensor_stats(x, dim, stats_type)
if not stats_type in this_dim_stats:
this_dim_stats[stats_type] = [] # list of TensorAndCount
done = False
if this_dim_stats[stats_type] is None:
# we can reach here if we detected for stats_type "eigs" that
# where was more than one different size for this dim. Then we
# disable accumulating this stats type, as it uses too much memory.
continue
for s in this_dim_stats[stats_type]:
if s.tensor.shape == stats.shape:
s.tensor += stats
s.count += count
done = True
break
if not done:
if this_dim_stats[stats_type] != [] and stats_type == "eigs":
# >1 size encountered on this dim, e.g. it's a batch or time dimension,
# don't accumulat "eigs" stats type, it uses too much memory
this_dim_stats[stats_type] = None
else:
this_dim_stats[stats_type].append(TensorAndCount(stats, count))
def print_diagnostics(self):
"""Print diagnostics for each dimension of the tensor."""
for dim, this_dim_stats in enumerate(self.stats):
for stats_type, stats_list in this_dim_stats.items():
# stats_type could be "rms", "value", "abs", "eigs", "positive".
# "value" could be a list of TensorAndCount, or None
if stats_list is None:
assert stats_type == "eigs"
continue
if stats_type == "eigs":
try:
stats = torch.stack(stats).sum(dim=0)
except: # noqa
return ""
count = sum(counts)
stats = stats / count
assert len(stats_list) == 1
stats = stats_list[0].tensor / stats_list[0].count
try:
eigs, _ = torch.symeig(stats)
stats = eigs.abs().sqrt()
@ -143,20 +199,19 @@ def get_diagnostics_for_dim(
eigs = torch.linalg.eigvals(stats)
stats = eigs.abs().sqrt()
# sqrt so it reflects data magnitude, like stddev- not variance
elif sizes_same:
stats = torch.stack(stats).sum(dim=0)
count = sum(counts)
stats = stats / count
elif len(stats_list) == 1:
stats = stats_list[0].tensor / stats_list[0].count
else:
stats = [x[0] / x[1] for x in stats_and_counts]
stats = torch.cat(stats, dim=0)
stats = torch.cat([x.tensor / x.count for x in stats_list], dim=0)
if stats_type == "rms":
# we stored the square; after aggregation we need to take sqrt.
stats = stats.sqrt()
# if `summarize` we print percentiles of the stats; else,
# we print out individual elements.
summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
if summarize:
summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(stats.numel())
if summarize: # usually `summarize` will be true
# print out percentiles.
stats = stats.sort()[0]
num_percentiles = 10
@ -178,123 +233,18 @@ def get_diagnostics_for_dim(
# speaking in an approximate way, how much of that largest eigenvalue
# can be attributed to the mean of the distribution.
norm = (stats ** 2).sum().sqrt().item()
mean = stats.mean().item()
rms = (stats ** 2).mean().sqrt().item()
ans += f", norm={norm:.2g}, mean={mean:.2g}, rms={rms:.2g}"
else:
ans += f", norm={norm:.2g}"
mean = stats.mean().item()
rms = (stats ** 2).mean().sqrt().item()
ans += f", mean={mean:.2g}, rms={rms:.2g}"
return ans
# OK, "ans" contains the actual stats, e.g.
# ans = "percentiles: [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], mean=0.5, rms=0.5"
def print_diagnostics_for_dim(
name: str, dim: int, tensors: List[Tensor], options: TensorDiagnosticOptions
):
"""This function prints diagnostics for a dimension of a tensor.
sizes = [x.tensor.shape[0] for x in stats_list]
size_str = f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
print(f"module={self.name}, dim={dim}, size={size_str}, {stats_type} {ans}")
Args:
name:
The tensor name.
dim:
The dimension to analyze, with 0 <= dim < tensors[0].ndim.
tensors:
List of cached tensors to get the stats.
options:
Options object.
"""
ndim = tensors[0].ndim
if ndim > 1:
stats_types = ["abs", "positive", "value", "rms"]
if tensors[0].shape[dim] <= options.max_eig_dim:
stats_types.append("eigs")
else:
stats_types = ["value", "abs"]
for stats_type in stats_types:
sizes = [x.shape[dim] for x in tensors]
sizes_same = all([x == sizes[0] for x in sizes])
s = get_diagnostics_for_dim(
dim, tensors, options, sizes_same, stats_type
)
if s == "":
continue
min_size = min(sizes)
max_size = max(sizes)
size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}"
# stats_type will be "abs" or "positive".
print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}")
class TensorDiagnostic(object):
"""This class is not directly used by the user, it is responsible for
collecting diagnostics for a single parameter tensor of a torch.nn.Module.
Args:
opts:
Options object.
name:
The tensor name.
"""
def __init__(self, opts: TensorDiagnosticOptions, name: str):
self.name = name
self.opts = opts
# A list to cache the tensors.
self.saved_tensors = []
def accumulate(self, x):
"""Accumulate tensors."""
if isinstance(x, Tuple):
x = x[0]
if not isinstance(x, Tensor):
return
if x.device == torch.device("cpu"):
x = x.detach().clone()
else:
x = x.detach().to("cpu", non_blocking=True)
self.saved_tensors.append(x)
num = len(self.saved_tensors)
if num & (num - 1) == 0: # power of 2..
self._limit_memory()
def _limit_memory(self):
"""Only keep the newly cached tensors to limit memory."""
if len(self.saved_tensors) > 1024:
self.saved_tensors = self.saved_tensors[-1024:]
return
tot_mem = 0.0
for i in reversed(range(len(self.saved_tensors))):
tot_mem += (
self.saved_tensors[i].numel()
* self.saved_tensors[i].element_size()
)
if tot_mem > self.opts.memory_limit:
self.saved_tensors = self.saved_tensors[i:]
return
def print_diagnostics(self):
"""Print diagnostics for each dimension of the tensor."""
if len(self.saved_tensors) == 0:
print("{name}: no stats".format(name=self.name))
return
if self.saved_tensors[0].ndim == 0:
# Ensure there is at least one dim.
self.saved_tensors = [x.unsqueeze(0) for x in self.saved_tensors]
try:
device = torch.device("cuda")
except: # noqa
device = torch.device("cpu")
ndim = self.saved_tensors[0].ndim
tensors = [x.to(device) for x in self.saved_tensors]
for dim in range(ndim):
print_diagnostics_for_dim(self.name, dim, tensors, self.opts)
class ModelDiagnostic(object):