Some cleanups

Daniel Povey 2022-04-04 13:37:10 +08:00
parent 0fd0828f79
commit 99e9d6c4b8
3 changed files with 5 additions and 757 deletions

View File

@@ -17,8 +17,6 @@
import torch
import torch.nn as nn
-from torch import Tensor
-from typing import Tuple

class Conv2dSubsampling(nn.Module):
@@ -44,27 +42,16 @@ class Conv2dSubsampling(nn.Module):
        assert idim >= 7
        super().__init__()
        self.conv = nn.Sequential(
-           ScaledConv2d(
+           nn.Conv2d(
                in_channels=1, out_channels=odim, kernel_size=3, stride=2
            ),
-           ActivationBalancer(channel_dim=1),
-           DoubleSwish(),
-           ScaledConv2d(
+           nn.ReLU(),
+           nn.Conv2d(
                in_channels=odim, out_channels=odim, kernel_size=3, stride=2
            ),
-           ActivationBalancer(channel_dim=1),
-           DoubleSwish(),
+           nn.ReLU(),
        )
-       self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim)
-       # set learn_eps=False because out_norm is preceded by `out`, and `out`
-       # itself has learned scale, so the extra degree of freedom is not
-       # needed.
-       self.out_norm = BasicNorm(odim, learn_eps=False)
-       # constrain median of output to be close to zero.
-       self.out_balancer = ActivationBalancer(channel_dim=-1,
-                                              min_positive=0.45,
-                                              max_positive=0.55)
+       self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
@@ -83,8 +70,6 @@ class Conv2dSubsampling(nn.Module):
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
-       x = self.out_norm(x)
-       x = self.out_balancer(x)
        return x
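As a quick check on the shape comment above: the two kernel-3, stride-2 convolutions subsample time and frequency exactly as the formula says. A minimal sketch (made-up sizes, not part of the commit):

import torch
import torch.nn as nn

conv = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, stride=2), nn.ReLU(),
    nn.Conv2d(8, 8, kernel_size=3, stride=2), nn.ReLU(),
)
T, idim = 100, 80
y = conv(torch.randn(1, 1, T, idim))
assert y.shape[2] == ((T - 1) // 2 - 1) // 2      # 24 time frames
assert y.shape[3] == ((idim - 1) // 2 - 1) // 2   # 19 frequency bins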
@@ -174,400 +159,3 @@ class VggSubsampling(nn.Module):
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
return x
class ActivationBalancerFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x: Tensor,
channel_dim: int,
min_positive: float, # e.g. 0.05
max_positive: float, # e.g. 0.95
max_factor: float, # e.g. 0.01
min_abs: float, # e.g. 0.2
max_abs: float, # e.g. 100.0
) -> Tensor:
if x.requires_grad:
if channel_dim < 0:
channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
xgt0 = x > 0
proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True)
factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive)
if min_positive != 0.0 else 0.0)
factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0))
if max_positive != 1.0 else 0.0)
factor = factor1 + factor2
if isinstance(factor, float):
factor = torch.zeros_like(proportion_positive)
mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
below_threshold = (mean_abs < min_abs)
above_threshold = (mean_abs > max_abs)
ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold)
ctx.max_factor = max_factor
ctx.sum_dims = sum_dims
return x
@staticmethod
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]:
factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
dtype = x_grad.dtype
scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) *
(xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0))
neg_delta_grad = x_grad.abs() * (factor + scale_factor)
return x_grad - neg_delta_grad, None, None, None, None, None, None
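To make the gradient surgery above concrete, a worked sketch with hypothetical numbers (not part of the original file): for a channel that is positive only 1% of the time, with min_positive=0.05 and max_factor=0.01, forward() computes factor1 = (0.05 - 0.01) * (0.01 / 0.05) = 0.008, and backward() then subtracts 0.008 * |grad| from the incoming gradient:

g = torch.tensor([0.5, -0.5])
g_mod = g - 0.008 * g.abs()   # tensor([0.4960, -0.5040])
# Both components shift downward, so descent-style updates push the
# activations upward, raising the proportion of positive values.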
class BasicNorm(torch.nn.Module):
"""
This is intended to be a simpler, and hopefully cheaper, replacement for
    LayerNorm. The observation this is based on is that Transformer-type
networks, especially with pre-norm, sometimes seem to set one of the
feature dimensions to a large constant value (e.g. 50), which "defeats"
the LayerNorm because the output magnitude is then not strongly dependent
on the other (useful) features. Presumably the weight and bias of the
LayerNorm are required to allow it to do this.
So the idea is to introduce this large constant value as an explicit
parameter, that takes the role of the "eps" in LayerNorm, so the network
doesn't have to do this trick. We make the "eps" learnable.
Args:
num_channels: the number of channels, e.g. 512.
       channel_dim: the axis/dimension corresponding to the channel,
         interpreted as an offset from the input's ndim if negative.
         This is NOT the num_channels; it should typically be one of
         {-2, -1, 0, 1, 2, 3}.
eps: the initial "epsilon" that we add as ballast in:
scale = ((input_vec**2).mean() + epsilon)**-0.5
Note: our epsilon is actually large, but we keep the name
to indicate the connection with conventional LayerNorm.
learn_eps: if true, we learn epsilon; if false, we keep it
at the initial value.
eps_speed: a constant that determines how fast "eps" learns;
with Adam and variants, this should probably be >= 1,
e.g. 5.0. For SGD and variants, probably a value less than one,
like 0.1, would be suitable, to prevent instability.
"""
def __init__(self,
num_channels: int,
channel_dim: int = -1, # CAUTION: see documentation.
eps: float = 0.25,
learn_eps: bool = True,
eps_speed: float = 5.0):
super(BasicNorm, self).__init__()
self.num_channels = num_channels
self.channel_dim = channel_dim
self.eps_speed = eps_speed
if learn_eps:
self.eps = nn.Parameter((torch.tensor(eps).log() / eps_speed).detach())
else:
self.register_buffer('eps', (torch.tensor(eps).log() / eps_speed).detach())
def forward(self, x: Tensor) -> Tensor:
assert x.shape[self.channel_dim] == self.num_channels
scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) +
(self.eps * self.eps_speed).exp()) ** -0.5
return x * scales
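A small self-check of the formula from the docstring, scale = ((x**2).mean(dim) + eps)**-0.5, with eps stored in log space (a sketch, not in the original file):

m = BasicNorm(num_channels=8, channel_dim=-1, eps=0.25, learn_eps=False)
x = torch.randn(3, 8)
eps = (m.eps * m.eps_speed).exp()   # recovers 0.25 from its log-space storage
expected = x * (x.pow(2).mean(dim=-1, keepdim=True) + eps) ** -0.5
assert torch.allclose(m(x), expected)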
class ScaledLinear(nn.Linear):
"""
A modified version of nn.Linear where the parameters are scaled before
use, via:
weight = self.weight * (self.weight_scale * self.scale_speed).exp()
bias = self.bias * (self.bias_scale * self.scale_speed).exp()
Args:
Accepts the standard args and kwargs that nn.Linear accepts
e.g. in_features, out_features, bias=False.
scale_speed: a factor that affects how fast the weight_scale
and bias_scale learn; this value is suitable for Adam-type
optimizers.
initial_scale: you can override this if you want to increase
or decrease the initial magnitude of the module's output
(affects the initialization of weight_scale and bias_scale).
Another option, if you want to do something like this, is
to re-initialize the parameters.
    Note: _reset_parameters() below overrides the default nn.Linear
    initialization: the stored weight gets a fixed std of 0.05, and the
    usual 1/sqrt(fan_in) scale is folded into weight_scale, so the
    effective weight still has magnitude ~ 1/sqrt(fan_in).
"""
def __init__(self, *args,
scale_speed: float = 5.0,
initial_scale: float = 1.0,
**kwargs):
super(ScaledLinear, self).__init__(*args, **kwargs)
initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
self.scale_speed = scale_speed
if self.bias is not None:
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
else:
self.register_parameter('bias_scale', None)
self._reset_parameters() # Overrides the reset_parameters in nn.Linear
def _reset_parameters(self):
std = 0.05
a = (3 ** 0.5) * std
nn.init.uniform_(self.weight, -a, a)
if self.bias is not None:
nn.init.constant_(self.bias, 0.0)
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
with torch.no_grad():
self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed)
def get_weight(self):
return self.weight * (self.weight_scale * self.scale_speed).exp()
def get_bias(self):
return (None if self.bias is None else
self.bias * (self.bias_scale * self.scale_speed).exp())
def forward(self, input: Tensor) -> Tensor:
return torch.nn.functional.linear(input, self.get_weight(),
self.get_bias())
class ScaledConv1d(nn.Conv1d):
def __init__(self, *args, scale_speed = 5.0,
initial_scale=1.0, **kwargs):
super(ScaledConv1d, self).__init__(*args, **kwargs)
self.scale_speed = scale_speed
initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
if self.bias is not None:
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
else:
self.register_parameter('bias_scale', None)
self._reset_parameters() # Overrides the reset_parameters in base class
def _reset_parameters(self):
std = 0.05
a = (3 ** 0.5) * std
nn.init.uniform_(self.weight, -a, a)
if self.bias is not None:
nn.init.constant_(self.bias, 0.0)
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
with torch.no_grad():
self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed)
def get_weight(self):
return self.weight * (self.weight_scale * self.scale_speed).exp()
def get_bias(self):
return (None if self.bias is None else
self.bias * (self.bias_scale * self.scale_speed).exp())
def forward(self, input: Tensor) -> Tensor:
F = torch.nn.functional
if self.padding_mode != 'zeros':
return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
self.get_weight(), self.get_bias(), self.stride,
_single(0), self.dilation, self.groups)
return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride,
self.padding, self.dilation, self.groups)
class ScaledConv2d(nn.Conv2d):
def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs):
super(ScaledConv2d, self).__init__(*args, **kwargs)
self.scale_speed = scale_speed
initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
if self.bias is not None:
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
else:
self.register_parameter('bias_scale', None)
self._reset_parameters() # Overrides the reset_parameters in base class
def _reset_parameters(self):
std = 0.05
a = (3 ** 0.5) * std
nn.init.uniform_(self.weight, -a, a)
if self.bias is not None:
nn.init.constant_(self.bias, 0.0)
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
with torch.no_grad():
self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed)
def get_weight(self):
return self.weight * (self.weight_scale * self.scale_speed).exp()
def get_bias(self):
return (None if self.bias is None else
self.bias * (self.bias_scale * self.scale_speed).exp())
def _conv_forward(self, input, weight):
F = torch.nn.functional
if self.padding_mode != 'zeros':
return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
weight, self.get_bias(), self.stride,
_pair(0), self.dilation, self.groups)
return F.conv2d(input, weight, self.get_bias(), self.stride,
self.padding, self.dilation, self.groups)
def forward(self, input: Tensor) -> Tensor:
return self._conv_forward(input, self.get_weight())
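The point of _reset_parameters in these Scaled* classes is that the stored weight keeps a fixed std of 0.05 while the 1/sqrt(fan_in) initialization scale is folded into weight_scale. A sketch checking the effective magnitude (sizes are made up):

m = ScaledConv2d(in_channels=4, out_channels=8, kernel_size=3)
fan_in = 4 * 3 * 3
print(m.weight.std().item())        # ~0.05, by construction
print(m.get_weight().std().item())  # ~fan_in ** -0.5, i.e. ~0.167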
class ActivationBalancer(torch.nn.Module):
"""
    Modifies the backpropped derivatives of a function to try to encourage, for
    each channel, that it is positive at least a proportion `min_positive` of
    the time and positive at most a proportion `max_positive` of the time. It
    does this by multiplying negative derivative values by up to (1+max_factor),
    and positive derivative values by up to (1-max_factor), interpolated from 1
    at the threshold to those extremal values when none of the inputs are
    positive.
Args:
channel_dim: the dimension/axis corresponding to the channel, e.g.
-1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
min_positive: the minimum, per channel, of the proportion of the time
that (x > 0), below which we start to modify the derivatives.
       max_positive: the maximum, per channel, of the proportion of the time
           that (x > 0), above which we start to modify the derivatives.
       max_factor: the maximum factor by which we modify the derivatives for
          either the sign constraint or the magnitude constraint;
          e.g. with max_factor=0.02, the derivatives would be multiplied by
          values in the range [0.98..1.02].
min_abs: the minimum average-absolute-value per channel, which
we allow, before we start to modify the derivatives to prevent
this.
max_abs: the maximum average-absolute-value per channel, which
we allow, before we start to modify the derivatives to prevent
this.
"""
def __init__(self, channel_dim: int,
min_positive: float = 0.05,
max_positive: float = 0.95,
max_factor: float = 0.01,
min_abs: float = 0.2,
max_abs: float = 100.0):
super(ActivationBalancer, self).__init__()
self.channel_dim = channel_dim
self.min_positive = min_positive
self.max_positive = max_positive
self.max_factor = max_factor
self.min_abs = min_abs
self.max_abs = max_abs
def forward(self, x: Tensor) -> Tensor:
return ActivationBalancerFunction.apply(x, self.channel_dim,
self.min_positive, self.max_positive,
self.max_factor, self.min_abs,
self.max_abs)
def _double_swish(x: Tensor) -> Tensor:
# double-swish, implemented/approximated as offset-swish
return x * torch.sigmoid(x - 1.0)
class DoubleSwishFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x: Tensor) -> Tensor:
ctx.save_for_backward(x.detach())
return _double_swish(x)
@staticmethod
def backward(ctx, y_grad: Tensor) -> Tensor:
# TODO: can make this more efficient.
x, = ctx.saved_tensors
x.requires_grad = True
with torch.enable_grad():
y = _double_swish(x)
y.backward(gradient=y_grad)
return x.grad
class DoubleSwish(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
that we approximate closely with x * sigmoid(x-1).
"""
return DoubleSwishFunction.apply(x)
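Regarding the TODO in DoubleSwishFunction.backward: the derivative has a cheap closed form, d/dx [x * sigmoid(x - 1)] = s + x * s * (1 - s) with s = sigmoid(x - 1). A sketch of that alternative (hypothetical helper, not in the file):

def _double_swish_grad(x: Tensor) -> Tensor:
    s = torch.sigmoid(x - 1.0)
    return s + x * s * (1.0 - s)

x = torch.randn(10, requires_grad=True)
y = DoubleSwishFunction.apply(x)
y.backward(torch.ones_like(x))
assert torch.allclose(x.grad, _double_swish_grad(x.detach()))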
def _test_deriv_balancer_sign():
channel_dim = 0
probs = torch.arange(0, 1, 0.01)
N = 1000
x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
x = x.detach()
x.requires_grad = True
m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95,
max_factor=0.2, min_abs=0.0)
y_grad = torch.sign(torch.randn(probs.numel(), N))
y = m(x)
y.backward(gradient=y_grad)
print("_test_deriv_balancer_sign: x = ", x)
print("_test_deriv_balancer_sign: y grad = ", y_grad)
print("_test_deriv_balancer_sign: x grad = ", x.grad)
def _test_deriv_balancer_magnitude():
channel_dim = 0
magnitudes = torch.arange(0, 1, 0.01)
N = 1000
x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
x = x.detach()
x.requires_grad = True
m = ActivationBalancer(channel_dim=0,
min_positive=0.0, max_positive=1.0,
max_factor=0.2,
min_abs=0.2, max_abs=0.8)
y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
y = m(x)
y.backward(gradient=y_grad)
print("_test_deriv_balancer_magnitude: x = ", x)
print("_test_deriv_balancer_magnitude: y grad = ", y_grad)
print("_test_deriv_balancer_magnitude: x grad = ", x.grad)
def _test_basic_norm():
num_channels = 128
m = BasicNorm(num_channels=num_channels, channel_dim=1)
x = torch.randn(500, num_channels)
y = m(x)
assert y.shape == x.shape
x_rms = (x**2).mean().sqrt()
y_rms = (y**2).mean().sqrt()
print("x rms = ", x_rms)
print("y rms = ", y_rms)
assert y_rms < x_rms
assert y_rms > 0.5 * x_rms
if __name__ == '__main__':
_test_deriv_balancer_sign()
_test_deriv_balancer_magnitude()
_test_basic_norm()

View File

@@ -22,8 +22,6 @@ import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
-import torch
-from lhotse.utils import fix_random_seed
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest

View File

@@ -1,338 +0,0 @@
import torch
from torch import Tensor
from torch import nn
import math
import random
from typing import Tuple, List
class TensorDiagnosticOptions(object):
"""
Options object for tensor diagnostics:
Args:
memory_limit: the maximum number of bytes we store per tensor (limits how many copies
of the tensor we cache).
max_eig_dim: the maximum dimension for which we print out eigenvalues
(limited for speed reasons).
"""
def __init__(self,
memory_limit: int = (2 ** 20),
max_eig_dim: int = 512):
self.memory_limit = memory_limit
self.max_eig_dim = max_eig_dim
def dim_is_summarized(self, size: int):
return size > 10 and size != 31
def get_tensor_stats(x: Tensor, dim: int,
stats_type: str) -> Tuple[Tensor, int]:
"""
    Returns the specified transformation of the Tensor (e.g. x, x.abs(),
    or (x > 0)), summed over all dimensions but `dim`.
Args:
x: Tensor, tensor to be analyzed
dim: dimension with 0 <= dim < x.ndim
stats_type:
"abs" -> take abs() before summing
"positive" -> take (x > 0) before summing
"rms" -> square before summing, we'll take sqrt later
"value -> just sum x itself
Returns (stats, count)
where stats is a Tensor of shape (x.shape[dim],), and the count
is an integer saying how many items were counted in each element
of stats.
"""
count = x.numel() // x.shape[dim]
if stats_type == "eigs":
x = x.transpose(dim, -1)
x = x.reshape(-1, x.shape[-1])
# shape of returned tensor: (s, s) where s is size of dimension `dim` of original x.
return torch.matmul(x.transpose(0, 1), x), count
elif stats_type == "abs":
x = x.abs()
elif stats_type == "rms":
x = x ** 2
elif stats_type == "positive":
x = (x > 0).to(dtype=torch.float)
else:
assert stats_type == "value"
sum_dims = [ d for d in range(x.ndim) if d != dim ]
if len(sum_dims) > 0:
x = torch.sum(x, dim=sum_dims)
x = x.flatten()
return x, count
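A worked example of the contract (made-up values, not from the file): for a (2, 3) tensor and dim=1, "abs" stats sum |x| over the other dimension, and count is how many items were summed into each element:

x = torch.tensor([[1., -2., 3.], [4., -5., 6.]])
stats, count = get_tensor_stats(x, dim=1, stats_type="abs")
# stats = tensor([5., 7., 9.]), count = 2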
def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
options: TensorDiagnosticOptions,
sizes_same: bool,
stats_type: str):
"""
This function gets diagnostics for a dimension of a module.
Args:
dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
options: options object
sizes_same: true if all the tensor sizes are the same on this dimension
      stats_type: either "abs" or "positive" or "eigs" or "value";
           this dictates the type of stats we accumulate: "abs" is the
           mean absolute value, "positive" is the proportion of values
           that are positive, "eigs" is eigenvalues after doing an outer
           product on this dim, summed over all other dims.
Returns:
       Diagnostic as a string, either percentiles or the actual values,
       see the code. Will return the empty string if the diagnostics did
       not make sense to print for this dimension, e.g. a dimension
       mismatch when stats_type == "eigs".
"""
# stats_and_counts is a list of pair (Tensor, int)
stats_and_counts = [ get_tensor_stats(x, dim, stats_type) for x in tensors ]
stats = [ x[0] for x in stats_and_counts ]
counts = [ x[1] for x in stats_and_counts ]
if stats_type == "eigs":
        try:
            stats = torch.stack(stats).sum(dim=0)
        except RuntimeError:  # e.g. sizes differ across tensors; skip "eigs"
            return ''
count = sum(counts)
stats = stats / count
stats, _ = torch.symeig(stats)
        stats = stats.abs().sqrt()  # sqrt so it reflects data magnitude, like stddev, not variance
elif sizes_same:
stats = torch.stack(stats).sum(dim=0)
count = sum(counts)
stats = stats / count
else:
stats = [ x[0] / x[1] for x in stats_and_counts ]
stats = torch.cat(stats, dim=0)
if stats_type == 'rms':
stats = stats.sqrt()
# if `summarize` we print percentiles of the stats; else,
# we print out individual elements.
summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
if summarize:
# print out percentiles.
stats = stats.sort()[0]
num_percentiles = 10
size = stats.numel()
percentiles = []
for i in range(num_percentiles + 1):
index = (i * (size - 1)) // num_percentiles
percentiles.append(stats[index].item())
percentiles = [ '%.2g' % x for x in percentiles ]
percentiles = ' '.join(percentiles)
ans = f'percentiles: [{percentiles}]'
else:
ans = stats.tolist()
ans = [ '%.2g' % x for x in ans ]
ans = '[' + ' '.join(ans) + ']'
if stats_type == "value":
# This norm is useful because it is strictly less than the largest
# sqrt(eigenvalue) of the variance, which we print out, and shows,
# speaking in an approximate way, how much of that largest eigenvalue
# can be attributed to the mean of the distribution.
norm = (stats ** 2).sum().sqrt().item()
mean = stats.mean().item()
rms = (stats ** 2).mean().sqrt().item()
ans += f', norm={norm:.2g}, mean={mean:.2g}, rms={rms:.2g}'
else:
mean = stats.mean().item()
rms = (stats ** 2).mean().sqrt().item()
ans += f', mean={mean:.2g}, rms={rms:.2g}'
return ans
def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor],
options: TensorDiagnosticOptions):
ndim = tensors[0].ndim
if ndim > 1:
stats_types = ["abs", "positive", "value", "rms"]
if tensors[0].shape[dim] <= options.max_eig_dim:
stats_types.append("eigs")
else:
stats_types = [ "value", "abs" ]
for stats_type in stats_types:
sizes = [ x.shape[dim] for x in tensors ]
sizes_same = all([ x == sizes[0] for x in sizes ])
s = get_diagnostics_for_dim(dim, tensors,
options, sizes_same,
stats_type)
if s == '':
continue
min_size = min(sizes)
max_size = max(sizes)
size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}"
        # stats_type will be one of "abs", "positive", "value", "rms", "eigs".
print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}")
class TensorDiagnostic(object):
"""
    This class is not directly used by the user; it is responsible for
    collecting diagnostics for a single tensor of a torch.nn.Module (an
    output, gradient, or parameter).
"""
def __init__(self,
opts: TensorDiagnosticOptions,
name: str):
self.name = name
self.opts = opts
self.saved_tensors = []
def accumulate(self, x):
        if isinstance(x, tuple):
x = x[0]
if not isinstance(x, Tensor):
return
if x.device == torch.device('cpu'):
x = x.detach().clone()
else:
x = x.detach().to('cpu', non_blocking=True)
self.saved_tensors.append(x)
l = len(self.saved_tensors)
if l & (l - 1) == 0: # power of 2..
self._limit_memory()
def _limit_memory(self):
if len(self.saved_tensors) > 1024:
self.saved_tensors = self.saved_tensors[-1024:]
return
tot_mem = 0.0
for i in reversed(range(len(self.saved_tensors))):
tot_mem += self.saved_tensors[i].numel() * self.saved_tensors[i].element_size()
if tot_mem > self.opts.memory_limit:
self.saved_tensors = self.saved_tensors[i:]
return
def print_diagnostics(self):
if len(self.saved_tensors) == 0:
print("{name}: no stats".format(name=self.name))
return
if self.saved_tensors[0].ndim == 0:
# ensure there is at least one dim.
self.saved_tensors = [ x.unsqueeze(0) for x in self.saved_tensors ]
        try:
            device = torch.device('cuda')
            torch.ones(1, 1, device=device)  # raises if CUDA is not available
        except Exception:
            device = torch.device('cpu')
ndim = self.saved_tensors[0].ndim
tensors = [x.to(device) for x in self.saved_tensors]
for dim in range(ndim):
print_diagnostics_for_dim(self.name, dim,
tensors,
self.opts)
class ModelDiagnostic(object):
def __init__(self, opts: TensorDiagnosticOptions = TensorDiagnosticOptions()):
self.diagnostics = dict()
self.opts = opts
def __getitem__(self, name: str):
if name not in self.diagnostics:
self.diagnostics[name] = TensorDiagnostic(self.opts, name)
return self.diagnostics[name]
def print_diagnostics(self):
for k in sorted(self.diagnostics.keys()):
self.diagnostics[k].print_diagnostics()
def attach_diagnostics(model: nn.Module,
opts: TensorDiagnosticOptions) -> ModelDiagnostic:
ans = ModelDiagnostic(opts)
for name, module in model.named_modules():
if name == '':
name = "<top-level>"
forward_diagnostic = TensorDiagnostic(opts, name + ".output")
backward_diagnostic = TensorDiagnostic(opts, name + ".grad")
        # Passing _model_diagnostic=ans and _name=name as default arguments
        # below, instead of relying on closure capture, ensures that we use
        # the current values (this matters for `name`, since the loop variable
        # gets overwritten): Python closures capture variables, not values,
        # so they would otherwise see only the final value.
def forward_hook(_module, _input, _output,
_model_diagnostic=ans, _name=name):
if isinstance(_output, Tensor):
_model_diagnostic[f"{_name}.output"].accumulate(_output)
elif isinstance(_output, tuple):
for i, o in enumerate(_output):
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(o)
def backward_hook(_module, _input, _output,
_model_diagnostic=ans, _name=name):
if isinstance(_output, Tensor):
_model_diagnostic[f"{_name}.grad"].accumulate(_output)
elif isinstance(_output, tuple):
for i, o in enumerate(_output):
_model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o)
module.register_forward_hook(forward_hook)
module.register_backward_hook(backward_hook)
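        # Note: newer PyTorch releases deprecate register_backward_hook in
        # favor of register_full_backward_hook; this code targets the older API.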
for name, parameter in model.named_parameters():
def param_backward_hook(grad,
_parameter=parameter,
_model_diagnostic=ans,
_name=name):
_model_diagnostic[f"{_name}.param_value"].accumulate(_parameter)
_model_diagnostic[f"{_name}.param_grad"].accumulate(grad)
parameter.register_hook(param_backward_hook)
return ans
def _test_tensor_diagnostic():
opts = TensorDiagnosticOptions(2**20, 512)
diagnostic = TensorDiagnostic(opts, "foo")
for _ in range(10):
diagnostic.accumulate(torch.randn(50, 100) * 10.0)
diagnostic.print_diagnostics()
model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80))
diagnostic = attach_diagnostics(model, opts)
for _ in range(10):
T = random.randint(200, 300)
x = torch.randn(T, 100)
y = model(x)
y.sum().backward()
diagnostic.print_diagnostics()
if __name__ == '__main__':
_test_tensor_diagnostic()
def _test_func():
    # Demonstrates the closure gotcha noted in attach_diagnostics(): each
    # `func` below captures the variable `x`, not its value at definition
    # time, so all ten returned closures yield the final list, [9].
    ans = []
    for i in range(10):
        x = list()
        x.append(i)
        def func():
            return x
        ans.append(func)
    return ans