Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 09:32:20 +00:00)
Improve diagnostics code memory-wise and accumulate more stats. (#373)
* Update diagnostics, hopefully print more stats.

  # Conflicts:
  #    egs/librispeech/ASR/pruned_transducer_stateless4b/train.py

* Remove memory-limit options arg

* Remove unnecessary option for diagnostics code, collect on more batches
This commit is contained in:
parent f6ce135608
commit 4e23fb2252
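Every recipe train.py below receives the same two-part change: diagnostics are now collected for 30 batches instead of 5, and the memory-limit argument disappears from the diagnostics setup. A minimal before/after sketch of the call sites (paraphrased from the hunks that follow, not copied from any single recipe):

    # Before this commit: a per-sub-module memory limit had to be passed so
    # that the cached tensor copies stayed bounded.
    opts = diagnostics.TensorDiagnosticOptions(2 ** 22)  # 4 MB per sub-module
    diagnostic = diagnostics.attach_diagnostics(model, opts)

    # After this commit: stats are kept as fixed-size running sums (see the
    # diagnostics.py hunks at the end), so no limit is needed.
    diagnostic = diagnostics.attach_diagnostics(model)

    # In train_one_epoch(): run longer before returning, for better-averaged stats.
    if params.print_diagnostics and batch_idx == 30:  # previously: batch_idx == 5
        return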
[train.py: GigaSpeech recipe; file path not shown in this mirror view]

@@ -689,7 +689,7 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()

-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
             return

         if (
@@ -831,10 +831,7 @@ def run(rank, world_size, args):
         scheduler.load_state_dict(checkpoints["scheduler"])

     if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)

     gigaspeech = GigaSpeechAsrDataModule(args)
[train.py: a LibriSpeech recipe; file path not shown in this mirror view]

@@ -695,7 +695,7 @@ def train_one_epoch(
             display_and_save_batch(batch, params=params, sp=sp)
             raise

-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
             return

         if (
@@ -839,10 +839,7 @@ def run(rank, world_size, args):
         scheduler.load_state_dict(checkpoints["scheduler"])

     if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)

     librispeech = LibriSpeechAsrDataModule(args)
[train.py: a LibriSpeech recipe; file path not shown in this mirror view]

@@ -767,7 +767,7 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()

-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
             return

         if (
@@ -938,10 +938,7 @@ def run(rank, world_size, args):
         scheduler.load_state_dict(checkpoints["scheduler"])

     if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)

     librispeech = LibriSpeech(manifest_dir=args.manifest_dir)
[train.py: a LibriSpeech recipe; file path not shown in this mirror view]

@@ -724,7 +724,7 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()

-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
             return

         if (
@@ -888,10 +888,7 @@ def run(rank, world_size, args):
         scheduler.load_state_dict(checkpoints["scheduler"])

     if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)

     librispeech = LibriSpeechAsrDataModule(args)
[train.py: a LibriSpeech (train-clean-100) recipe; file path not shown in this mirror view]

@@ -523,7 +523,7 @@ def train_one_epoch(
         loss.backward()
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()
-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
             return

         if batch_idx % params.log_interval == 0:
@@ -635,10 +635,7 @@ def run(rank, world_size, args):
     librispeech = LibriSpeechAsrDataModule(args)

     if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)

     train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
[train.py: a LibriSpeech (train-clean-100) recipe; file path not shown in this mirror view]

@@ -511,7 +511,7 @@ def train_one_epoch(
         loss.backward()
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()
-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
             return

         if batch_idx % params.log_interval == 0:
@@ -623,10 +623,7 @@ def run(rank, world_size, args):
     librispeech = LibriSpeechAsrDataModule(args)

     if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)

     train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
[train.py: SPGISpeech recipe; file path not shown in this mirror view]

@@ -690,7 +690,7 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()

-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
             return

         if (
@@ -832,10 +832,7 @@ def run(rank, world_size, args):
         scheduler.load_state_dict(checkpoints["scheduler"])

     if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)

     spgispeech = SPGISpeechAsrDataModule(args)
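The remaining hunks all touch the diagnostics module itself. Useful context: get_tensor_stats(x, dim, stats_type), which this commit keeps (only its "return x, count" line appears as context below), reduces x over every dimension except dim and also returns how many elements went into each reduced value. A simplified sketch of that contract, not the module's exact code:

    import torch
    from torch import Tensor
    from typing import Tuple

    def tensor_stats_sketch(x: Tensor, dim: int, stats_type: str) -> Tuple[Tensor, int]:
        # Transform per stats_type, then sum away every dim except `dim`.
        if stats_type == "abs":         # for mean absolute value
            x = x.abs()
        elif stats_type == "positive":  # for proportion of positive values
            x = (x > 0).to(torch.float32)
        elif stats_type == "rms":       # squared here; sqrt happens at print time
            x = x ** 2
        # "value" keeps x as-is; "eigs" (not shown) sums an outer product instead.
        count = x.numel() // x.shape[dim]
        dims_to_reduce = [d for d in range(x.ndim) if d != dim]
        if dims_to_reduce:
            x = x.sum(dim=dims_to_reduce)
        return x, count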
[diagnostics.py: the icefall diagnostics module]

@@ -19,7 +19,7 @@
 import random
 from typing import List, Optional, Tuple
-
+from dataclasses import dataclass
 import torch
 from torch import Tensor, nn
@@ -28,16 +28,12 @@ class TensorDiagnosticOptions(object):
     """Options object for tensor diagnostics:

     Args:
-      memory_limit:
-        The maximum number of bytes per tensor
-        (limits how many copies of the tensor we cache).
       max_eig_dim:
         The maximum dimension for which we print out eigenvalues
         (limited for speed reasons).
     """

-    def __init__(self, memory_limit: int = (2 ** 20), max_eig_dim: int = 512):
-        self.memory_limit = memory_limit
+    def __init__(self, max_eig_dim: int = 512):
         self.max_eig_dim = max_eig_dim

     def dim_is_summarized(self, size: int):
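With memory_limit gone, the options object only gates the eigenvalue analysis. A minimal usage sketch of the new constructor (the value shown is just the default from the signature above):

    from icefall import diagnostics  # as in the recipe train.py files

    opts = diagnostics.TensorDiagnosticOptions(max_eig_dim=512)

    # A dimension participates in the "eigs" stats type only while
    # x.shape[dim] <= opts.max_eig_dim: eigendecomposing a size-by-size
    # covariance is O(size^3), so large dims are skipped for speed.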
@@ -94,138 +90,12 @@ def get_tensor_stats(
     return x, count


-def get_diagnostics_for_dim(
-    dim: int,
-    tensors: List[Tensor],
-    options: TensorDiagnosticOptions,
-    sizes_same: bool,
-    stats_type: str,
-) -> str:
-    """
-    This function gets diagnostics for a dimension of a module.
-
-    Args:
-      dim:
-        the dimension to analyze, with 0 <= dim < tensors[0].ndim
-      options:
-        options object
-      sizes_same:
-        True if all the tensor sizes are the same on this dimension
-      stats_type: either "abs" or "positive" or "eigs" or "value";
-        dictates the type of stats we accumulate.  "abs" is mean absolute
-        value, "positive" is proportion of positive to nonnegative values,
-        "eigs" is eigenvalues after doing outer product on this dim, summed
-        over all other dims.
-    Returns:
-      Diagnostic as a string, either percentiles or the actual values,
-      see the code.  Will return the empty string if the diagnostics did
-      not make sense to print out for this dimension, e.g. dimension
-      mismatch and stats_type == "eigs".
-    """
-
-    # stats_and_counts is a list of pair (Tensor, int)
-    stats_and_counts = [get_tensor_stats(x, dim, stats_type) for x in tensors]
-    stats = [x[0] for x in stats_and_counts]
-    counts = [x[1] for x in stats_and_counts]
-
-    if stats_type == "eigs":
-        try:
-            stats = torch.stack(stats).sum(dim=0)
-        except:  # noqa
-            return ""
-        count = sum(counts)
-        stats = stats / count
-        try:
-            eigs, _ = torch.symeig(stats)
-            stats = eigs.abs().sqrt()
-        except:  # noqa
-            print("Error getting eigenvalues, trying another method.")
-            eigs = torch.linalg.eigvals(stats)
-            stats = eigs.abs().sqrt()
-        # sqrt so it reflects data magnitude, like stddev, not variance
-    elif sizes_same:
-        stats = torch.stack(stats).sum(dim=0)
-        count = sum(counts)
-        stats = stats / count
-    else:
-        stats = [x[0] / x[1] for x in stats_and_counts]
-        stats = torch.cat(stats, dim=0)
-    if stats_type == "rms":
-        stats = stats.sqrt()
-
-    # if `summarize` we print percentiles of the stats; else,
-    # we print out individual elements.
-    summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
-    if summarize:
-        # print out percentiles.
-        stats = stats.sort()[0]
-        num_percentiles = 10
-        size = stats.numel()
-        percentiles = []
-        for i in range(num_percentiles + 1):
-            index = (i * (size - 1)) // num_percentiles
-            percentiles.append(stats[index].item())
-        percentiles = ["%.2g" % x for x in percentiles]
-        percentiles = " ".join(percentiles)
-        ans = f"percentiles: [{percentiles}]"
-    else:
-        ans = stats.tolist()
-        ans = ["%.2g" % x for x in ans]
-        ans = "[" + " ".join(ans) + "]"
-    if stats_type == "value":
-        # This norm is useful because it is strictly less than the largest
-        # sqrt(eigenvalue) of the variance, which we print out, and shows,
-        # speaking in an approximate way, how much of that largest eigenvalue
-        # can be attributed to the mean of the distribution.
-        norm = (stats ** 2).sum().sqrt().item()
-        mean = stats.mean().item()
-        rms = (stats ** 2).mean().sqrt().item()
-        ans += f", norm={norm:.2g}, mean={mean:.2g}, rms={rms:.2g}"
-    else:
-        mean = stats.mean().item()
-        rms = (stats ** 2).mean().sqrt().item()
-        ans += f", mean={mean:.2g}, rms={rms:.2g}"
-    return ans
-
-
-def print_diagnostics_for_dim(
-    name: str, dim: int, tensors: List[Tensor], options: TensorDiagnosticOptions
-):
-    """This function prints diagnostics for a dimension of a tensor.
-
-    Args:
-      name:
-        The tensor name.
-      dim:
-        The dimension to analyze, with 0 <= dim < tensors[0].ndim.
-      tensors:
-        List of cached tensors to get the stats.
-      options:
-        Options object.
-    """
-
-    ndim = tensors[0].ndim
-    if ndim > 1:
-        stats_types = ["abs", "positive", "value", "rms"]
-        if tensors[0].shape[dim] <= options.max_eig_dim:
-            stats_types.append("eigs")
-    else:
-        stats_types = ["value", "abs"]
-
-    for stats_type in stats_types:
-        sizes = [x.shape[dim] for x in tensors]
-        sizes_same = all([x == sizes[0] for x in sizes])
-        s = get_diagnostics_for_dim(
-            dim, tensors, options, sizes_same, stats_type
-        )
-        if s == "":
-            continue
-
-        min_size = min(sizes)
-        max_size = max(sizes)
-        size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}"
-        # stats_type will be "abs" or "positive".
-        print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}")
+@dataclass
+class TensorAndCount:
+    tensor: Tensor
+    count: int


 class TensorDiagnostic(object):
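The TensorAndCount dataclass added above is the core of the memory fix: the two deleted functions operated on a list of every cached tensor, while the new code keeps only a running sum and a count per (dim, stats-type, size). A toy illustration of the saving (sizes hypothetical):

    import torch
    from dataclasses import dataclass
    from torch import Tensor

    @dataclass
    class TensorAndCount:
        tensor: Tensor  # running sum of per-batch statistics
        count: int      # total number of elements aggregated into the sum

    # 1000 batches of "abs" stats over shape (8, 512) cost one 512-float tensor,
    # instead of 1000 cached activation copies under the old memory_limit scheme.
    acc = TensorAndCount(tensor=torch.zeros(512), count=0)
    for _ in range(1000):
        x = torch.randn(8, 512)
        acc.tensor += x.abs().sum(dim=0)  # reduce over all dims except channel
        acc.count += x.shape[0]
    mean_abs = acc.tensor / acc.count  # the same mean the old code computed lazily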
@@ -238,12 +108,23 @@ class TensorDiagnostic(object):
       name:
         The tensor name.
     """

     def __init__(self, opts: TensorDiagnosticOptions, name: str):
         self.name = name
         self.opts = opts
-        # A list to cache the tensors.
-        self.saved_tensors = []
+
+        self.stats = None  # we'll later assign a list to this data member.  It's a list of dict.
+
+        # The keys into self.stats[dim] are strings, whose values can be
+        # "abs", "value", "positive", "rms".
+        # The values, e.g. self.stats[dim]["rms"], are lists of the dataclass
+        # TensorAndCount, containing a tensor and its associated count (which is
+        # the sum of the other dims that we aggregated over, e.g. the number of
+        # frames and/or batch elements and/or channels).
+        # ... we actually accumulate the tensors / counts any time we see a
+        # tensor of the same size, only adding a new element to the list if
+        # there was a different size.
+        # If the key is "eigs" and we detect a size mismatch, we put None as
+        # the value.

     def accumulate(self, x):
         """Accumulate tensors."""
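The accumulation policy those comments describe is shown in miniature below: same-shape stats merge into an existing TensorAndCount, a new size appends a new entry, and "eigs" is instead disabled (set to None) when sizes differ, since per-size covariance sums would be too big. A self-contained toy with hypothetical shapes; the real logic is in accumulate() in the next hunk:

    import torch
    from dataclasses import dataclass
    from torch import Tensor

    @dataclass
    class TensorAndCount:  # mirrors the dataclass this commit adds
        tensor: Tensor
        count: int

    entries = []  # plays the role of self.stats[dim]["abs"]
    for t in [torch.ones(256), torch.ones(256), torch.ones(300)]:
        for e in entries:
            if e.tensor.shape == t.shape:
                e.tensor += t  # same size on this dim: merge into running sum
                e.count += 1
                break
        else:
            entries.append(TensorAndCount(t.clone(), 1))  # new size: new entry

    print(len(entries))  # 2: one (256,)-entry with count 2, one (300,)-entry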
@@ -251,50 +132,115 @@ class TensorDiagnostic(object):
             x = x[0]
         if not isinstance(x, Tensor):
             return
-        if x.device == torch.device("cpu"):
-            x = x.detach().clone()
-        else:
-            x = x.detach().to("cpu", non_blocking=True)
-        self.saved_tensors.append(x)
-        num = len(self.saved_tensors)
-        if num & (num - 1) == 0:  # power of 2..
-            self._limit_memory()
-
-    def _limit_memory(self):
-        """Only keep the newly cached tensors to limit memory."""
-        if len(self.saved_tensors) > 1024:
-            self.saved_tensors = self.saved_tensors[-1024:]
-            return
-        tot_mem = 0.0
-        for i in reversed(range(len(self.saved_tensors))):
-            tot_mem += (
-                self.saved_tensors[i].numel()
-                * self.saved_tensors[i].element_size()
-            )
-            if tot_mem > self.opts.memory_limit:
-                self.saved_tensors = self.saved_tensors[i:]
-                return
+        x = x.detach().clone()
+        if x.ndim == 0:
+            x = x.unsqueeze(0)
+        ndim = x.ndim
+        if self.stats is None:
+            self.stats = [ dict() for _ in range(ndim) ]
+
+        for dim in range(ndim):
+            this_dim_stats = self.stats[dim]
+            if ndim > 1:
+                stats_types = ["abs", "positive", "value", "rms"]
+                if x.shape[dim] <= self.opts.max_eig_dim:
+                    stats_types.append("eigs")
+            else:
+                stats_types = ["value", "abs"]
+            this_dict = self.stats[dim]
+            for stats_type in stats_types:
+                stats, count = get_tensor_stats(x, dim, stats_type)
+                if not stats_type in this_dim_stats:
+                    this_dim_stats[stats_type] = []  # list of TensorAndCount
+
+                done = False
+                if this_dim_stats[stats_type] is None:
+                    # We can reach here if we detected for stats_type "eigs"
+                    # that there was more than one different size for this dim.
+                    # Then we disable accumulating this stats type, as it uses
+                    # too much memory.
+                    continue
+                for s in this_dim_stats[stats_type]:
+                    if s.tensor.shape == stats.shape:
+                        s.tensor += stats
+                        s.count += count
+                        done = True
+                        break
+                if not done:
+                    if this_dim_stats[stats_type] != [] and stats_type == "eigs":
+                        # >1 size encountered on this dim, e.g. it's a batch or
+                        # time dimension; don't accumulate the "eigs" stats
+                        # type, it uses too much memory.
+                        this_dim_stats[stats_type] = None
+                    else:
+                        this_dim_stats[stats_type].append(TensorAndCount(stats, count))

     def print_diagnostics(self):
         """Print diagnostics for each dimension of the tensor."""
-        if len(self.saved_tensors) == 0:
-            print("{name}: no stats".format(name=self.name))
-            return
-        if self.saved_tensors[0].ndim == 0:
-            # Ensure there is at least one dim.
-            self.saved_tensors = [x.unsqueeze(0) for x in self.saved_tensors]
-        try:
-            device = torch.device("cuda")
-        except:  # noqa
-            device = torch.device("cpu")
-        ndim = self.saved_tensors[0].ndim
-        tensors = [x.to(device) for x in self.saved_tensors]
-        for dim in range(ndim):
-            print_diagnostics_for_dim(self.name, dim, tensors, self.opts)
+        for dim, this_dim_stats in enumerate(self.stats):
+            for stats_type, stats_list in this_dim_stats.items():
+                # stats_type could be "rms", "value", "abs", "eigs", "positive";
+                # stats_list could be a list of TensorAndCount, or None.
+                if stats_list is None:
+                    assert stats_type == "eigs"
+                    continue
+
+                if stats_type == "eigs":
+                    assert len(stats_list) == 1
+                    stats = stats_list[0].tensor / stats_list[0].count
+                    try:
+                        eigs, _ = torch.symeig(stats)
+                        stats = eigs.abs().sqrt()
+                    except:  # noqa
+                        print("Error getting eigenvalues, trying another method.")
+                        eigs = torch.linalg.eigvals(stats)
+                        stats = eigs.abs().sqrt()
+                    # sqrt so it reflects data magnitude, like stddev, not variance
+                elif len(stats_list) == 1:
+                    stats = stats_list[0].tensor / stats_list[0].count
+                else:
+                    stats = torch.cat([x.tensor / x.count for x in stats_list], dim=0)
+
+                if stats_type == "rms":
+                    # We stored the square; after aggregation we take the sqrt.
+                    stats = stats.sqrt()
+
+                # If `summarize`, we print percentiles of the stats; else,
+                # we print out individual elements.
+                summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(stats.numel())
+                if summarize:  # usually `summarize` will be true
+                    # print out percentiles.
+                    stats = stats.sort()[0]
+                    num_percentiles = 10
+                    size = stats.numel()
+                    percentiles = []
+                    for i in range(num_percentiles + 1):
+                        index = (i * (size - 1)) // num_percentiles
+                        percentiles.append(stats[index].item())
+                    percentiles = ["%.2g" % x for x in percentiles]
+                    percentiles = " ".join(percentiles)
+                    ans = f"percentiles: [{percentiles}]"
+                else:
+                    ans = stats.tolist()
+                    ans = ["%.2g" % x for x in ans]
+                    ans = "[" + " ".join(ans) + "]"
+
+                if stats_type == "value":
+                    # This norm is useful because it is strictly less than the
+                    # largest sqrt(eigenvalue) of the variance, which we print
+                    # out, and shows, speaking in an approximate way, how much
+                    # of that largest eigenvalue can be attributed to the mean
+                    # of the distribution.
+                    norm = (stats ** 2).sum().sqrt().item()
+                    ans += f", norm={norm:.2g}"
+                mean = stats.mean().item()
+                rms = (stats ** 2).mean().sqrt().item()
+                ans += f", mean={mean:.2g}, rms={rms:.2g}"
+
+                # Now "ans" contains the actual stats, e.g.
+                # ans = "percentiles: [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], mean=0.5, rms=0.5"
+
+                sizes = [x.tensor.shape[0] for x in stats_list]
+                size_str = f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
+                print(f"module={self.name}, dim={dim}, size={size_str}, {stats_type} {ans}")


 class ModelDiagnostic(object):
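Putting the pieces together, the intended usage matches the recipe call sites above; attach_diagnostics() and print_diagnostics() are the module's real entry points, while the loop body here is schematic:

    from icefall import diagnostics  # as the recipe train.py files do

    diagnostic = diagnostics.attach_diagnostics(model)  # hooks each sub-module

    for batch_idx, batch in enumerate(train_dl):
        loss = compute_loss(model, batch)  # placeholder; every recipe differs
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if batch_idx == 30:  # the new, larger collection window
            break

    diagnostic.print_diagnostics()  # per-module, per-dim stats: percentiles,
                                    # mean, rms, and eigenvalues where enabled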