mirror of https://github.com/k2-fsa/icefall.git

commit 99e9d6c4b8 ("Some cleanups")
parent 0fd0828f79
@@ -17,8 +17,6 @@

import torch
import torch.nn as nn
from torch import Tensor
from typing import Tuple


class Conv2dSubsampling(nn.Module):
@@ -44,27 +42,16 @@ class Conv2dSubsampling(nn.Module):

        assert idim >= 7
        super().__init__()
        self.conv = nn.Sequential(
            ScaledConv2d(
            nn.Conv2d(
                in_channels=1, out_channels=odim, kernel_size=3, stride=2
            ),
            ActivationBalancer(channel_dim=1),
            DoubleSwish(),
            ScaledConv2d(
            nn.ReLU(),
            nn.Conv2d(
                in_channels=odim, out_channels=odim, kernel_size=3, stride=2
            ),
            ActivationBalancer(channel_dim=1),
            DoubleSwish(),
            nn.ReLU(),
        )
        self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim)
        # set learn_eps=False because out_norm is preceded by `out`, and `out`
        # itself has a learned scale, so the extra degree of freedom is not
        # needed.
        self.out_norm = BasicNorm(odim, learn_eps=False)
        # constrain the median of the output to be close to zero.
        self.out_balancer = ActivationBalancer(channel_dim=-1,
                                               min_positive=0.45,
                                               max_positive=0.55)

        self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x.

@@ -83,8 +70,6 @@ class Conv2dSubsampling(nn.Module):

        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
        x = self.out_norm(x)
        x = self.out_balancer(x)
        return x
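# --- Illustrative sketch, not part of the commit: each stride-2, kernel-3
# Conv2d above maps a length L to (L - 3) // 2 + 1 == (L - 1) // 2, and doing
# it twice is where the odim * (((idim - 1) // 2 - 1) // 2) input size of
# `self.out` comes from.
def _sketch_subsampled_lengths():
    def after_one_conv(n: int) -> int:
        # length after Conv2d(kernel_size=3, stride=2, padding=0)
        return (n - 3) // 2 + 1

    for length in (7, 80, 100):
        assert after_one_conv(after_one_conv(length)) == ((length - 1) // 2 - 1) // 2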
@@ -174,400 +159,3 @@ class VggSubsampling(nn.Module):

        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        return x


class ActivationBalancerFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor,
                channel_dim: int,
                min_positive: float,  # e.g. 0.05
                max_positive: float,  # e.g. 0.95
                max_factor: float,    # e.g. 0.01
                min_abs: float,       # e.g. 0.2
                max_abs: float,       # e.g. 100.0
                ) -> Tensor:
        if x.requires_grad:
            if channel_dim < 0:
                channel_dim += x.ndim
            sum_dims = [d for d in range(x.ndim) if d != channel_dim]
            xgt0 = x > 0
            proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True)
            factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive)
                       if min_positive != 0.0 else 0.0)
            factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0))
                       if max_positive != 1.0 else 0.0)
            factor = factor1 + factor2
            if isinstance(factor, float):
                factor = torch.zeros_like(proportion_positive)

            mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
            below_threshold = (mean_abs < min_abs)
            above_threshold = (mean_abs > max_abs)

            ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold)
            ctx.max_factor = max_factor
            ctx.sum_dims = sum_dims
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]:
        factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
        dtype = x_grad.dtype
        scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) *
                        (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0))

        neg_delta_grad = x_grad.abs() * (factor + scale_factor)
        return x_grad - neg_delta_grad, None, None, None, None, None, None

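# --- Illustrative sketch, not part of the diff: how the sign-balancing factor
# computed in ActivationBalancerFunction.forward behaves for three example
# channels under the default settings.  All names here are local to the sketch.
def _sketch_balancer_factor():
    min_positive, max_positive, max_factor = 0.05, 0.95, 0.01
    proportion_positive = torch.tensor([0.01, 0.50, 0.99])
    factor1 = (min_positive - proportion_positive).relu() * (max_factor / min_positive)
    factor2 = (proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0))
    # A channel positive only 1% of the time gets +0.008, one positive 99% of
    # the time gets -0.008, and a balanced channel is left alone; backward()
    # then adds or subtracts that fraction of |grad| depending on sign(x).
    print(factor1 + factor2)  # tensor([ 0.0080,  0.0000, -0.0080])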
class BasicNorm(torch.nn.Module):
    """
    This is intended to be a simpler, and hopefully cheaper, replacement for
    LayerNorm.  The observation this is based on is that Transformer-type
    networks, especially with pre-norm, sometimes seem to set one of the
    feature dimensions to a large constant value (e.g. 50), which "defeats"
    the LayerNorm because the output magnitude is then not strongly dependent
    on the other (useful) features.  Presumably the weight and bias of the
    LayerNorm are required to allow it to do this.

    So the idea is to introduce this large constant value as an explicit
    parameter that takes the role of the "eps" in LayerNorm, so the network
    doesn't have to do this trick.  We make the "eps" learnable.

    Args:
       num_channels: the number of channels, e.g. 512.
       channel_dim: the axis/dimension corresponding to the channel,
          interpreted as an offset from the input's ndim if negative.
          This is NOT the num_channels; it should typically be one of
          {-2, -1, 0, 1, 2, 3}.
       eps: the initial "epsilon" that we add as ballast in:
             scale = ((input_vec**2).mean() + epsilon)**-0.5
          Note: our epsilon is actually large, but we keep the name
          to indicate the connection with conventional LayerNorm.
       learn_eps: if true, we learn epsilon; if false, we keep it
          at the initial value.
       eps_speed: a constant that determines how fast "eps" learns;
          with Adam and variants, this should probably be >= 1,
          e.g. 5.0.  For SGD and variants, probably a value less than one,
          like 0.1, would be suitable, to prevent instability.
    """
    def __init__(self,
                 num_channels: int,
                 channel_dim: int = -1,  # CAUTION: see documentation.
                 eps: float = 0.25,
                 learn_eps: bool = True,
                 eps_speed: float = 5.0):
        super(BasicNorm, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.eps_speed = eps_speed
        if learn_eps:
            self.eps = nn.Parameter((torch.tensor(eps).log() / eps_speed).detach())
        else:
            self.register_buffer('eps', (torch.tensor(eps).log() / eps_speed).detach())

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.channel_dim] == self.num_channels
        scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) +
                  (self.eps * self.eps_speed).exp()) ** -0.5
        return x * scales

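# --- Illustrative sketch, not part of the diff: BasicNorm is essentially an
# RMS-style normalization without per-channel weights, where the "eps" ballast
# is stored as log(eps) / eps_speed and recovered via
# (self.eps * self.eps_speed).exp().
def _sketch_basic_norm():
    x = torch.randn(4, 8)
    eps = 0.25  # default initial value
    scales = (x.pow(2).mean(dim=-1, keepdim=True) + eps) ** -0.5
    y = x * scales
    # per-frame mean square is a bit below 1.0 because of the eps ballast
    print(y.pow(2).mean(dim=-1))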
class ScaledLinear(nn.Linear):
    """
    A modified version of nn.Linear where the parameters are scaled before
    use, via:
         weight = self.weight * (self.weight_scale * self.scale_speed).exp()
         bias = self.bias * (self.bias_scale * self.scale_speed).exp()

    Args:
        Accepts the standard args and kwargs that nn.Linear accepts,
        e.g. in_features, out_features, bias=False.

        scale_speed: a factor that affects how fast the weight_scale
           and bias_scale learn; this value is suitable for Adam-type
           optimizers.
        initial_scale: you can override this if you want to increase
           or decrease the initial magnitude of the module's output
           (affects the initialization of weight_scale and bias_scale).
           Another option, if you want to do something like this, is
           to re-initialize the parameters.

    Note: it uses the default initialization for the weight and bias,
    inherited from nn.Linear.  For modules with small fan-in, this
    may be larger than optimal.
    """
    def __init__(self, *args,
                 scale_speed: float = 5.0,
                 initial_scale: float = 1.0,
                 **kwargs):
        super(ScaledLinear, self).__init__(*args, **kwargs)
        initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
        self.scale_speed = scale_speed
        if self.bias is not None:
            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
        else:
            self.register_parameter('bias_scale', None)

        self._reset_parameters()  # Overrides the reset_parameters in nn.Linear

    def _reset_parameters(self):
        std = 0.05
        a = (3 ** 0.5) * std
        nn.init.uniform_(self.weight, -a, a)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.0)
        fan_in = self.weight.shape[1] * self.weight[0][0].numel()
        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
        with torch.no_grad():
            self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed)

    def get_weight(self):
        return self.weight * (self.weight_scale * self.scale_speed).exp()

    def get_bias(self):
        return (None if self.bias is None else
                self.bias * (self.bias_scale * self.scale_speed).exp())

    def forward(self, input: Tensor) -> Tensor:
        return torch.nn.functional.linear(input, self.get_weight(),
                                          self.get_bias())

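# --- Illustrative sketch, not part of the diff (assumes the ScaledLinear class
# above is in scope): at initialization the *effective* weight returned by
# get_weight() has roughly std 1/sqrt(fan_in), like a standard init, even
# though self.weight itself is drawn with std 0.05; the ratio lives in
# weight_scale, which the optimizer then adapts on a log scale.
def _sketch_scaled_linear_init():
    lin = ScaledLinear(256, 512)
    print(lin.weight.std())        # ~0.05
    print(lin.get_weight().std())  # ~256 ** -0.5 == 0.0625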
class ScaledConv1d(nn.Conv1d):
    def __init__(self, *args, scale_speed=5.0,
                 initial_scale=1.0, **kwargs):
        super(ScaledConv1d, self).__init__(*args, **kwargs)
        self.scale_speed = scale_speed
        initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
        if self.bias is not None:
            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
        else:
            self.register_parameter('bias_scale', None)
        self._reset_parameters()  # Overrides the reset_parameters in base class

    def _reset_parameters(self):
        std = 0.05
        a = (3 ** 0.5) * std
        nn.init.uniform_(self.weight, -a, a)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.0)
        fan_in = self.weight.shape[1] * self.weight[0][0].numel()
        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
        with torch.no_grad():
            self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed)

    def get_weight(self):
        return self.weight * (self.weight_scale * self.scale_speed).exp()

    def get_bias(self):
        return (None if self.bias is None else
                self.bias * (self.bias_scale * self.scale_speed).exp())

    def forward(self, input: Tensor) -> Tensor:
        F = torch.nn.functional
        if self.padding_mode != 'zeros':
            return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice,
                                  mode=self.padding_mode),
                            self.get_weight(), self.get_bias(), self.stride,
                            _single(0), self.dilation, self.groups)
        return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride,
                        self.padding, self.dilation, self.groups)

class ScaledConv2d(nn.Conv2d):
    def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs):
        super(ScaledConv2d, self).__init__(*args, **kwargs)
        self.scale_speed = scale_speed
        initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
        if self.bias is not None:
            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
        else:
            self.register_parameter('bias_scale', None)
        self._reset_parameters()  # Overrides the reset_parameters in base class

    def _reset_parameters(self):
        std = 0.05
        a = (3 ** 0.5) * std
        nn.init.uniform_(self.weight, -a, a)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.0)
        fan_in = self.weight.shape[1] * self.weight[0][0].numel()
        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
        with torch.no_grad():
            self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed)

    def get_weight(self):
        return self.weight * (self.weight_scale * self.scale_speed).exp()

    def get_bias(self):
        return (None if self.bias is None else
                self.bias * (self.bias_scale * self.scale_speed).exp())

    def _conv_forward(self, input, weight):
        F = torch.nn.functional
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice,
                                  mode=self.padding_mode),
                            weight, self.get_bias(), self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.get_bias(), self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input: Tensor) -> Tensor:
        return self._conv_forward(input, self.get_weight())

class ActivationBalancer(torch.nn.Module):
    """
    Modifies the backpropped derivatives of a function to try to encourage, for
    each channel, that it is positive at least a proportion `min_positive` of
    the time (and no more than `max_positive`).  It does this by multiplying
    negative derivative values by up to (1+max_factor), and positive derivative
    values by up to (1-max_factor), interpolated from 1 at the threshold to
    those extremal values when none of the inputs are positive.

    Args:
           channel_dim: the dimension/axis corresponding to the channel, e.g.
               -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
           min_positive: the minimum, per channel, of the proportion of the time
               that (x > 0), below which we start to modify the derivatives.
           max_positive: the maximum, per channel, of the proportion of the time
               that (x > 0), above which we start to modify the derivatives.
           max_factor: the maximum factor by which we modify the derivatives for
              either the sign constraint or the magnitude constraint;
              e.g. with max_factor=0.02, the derivatives would be multiplied by
              values in the range [0.98..1.02].
           min_abs: the minimum average absolute value, per channel, that we
              allow; below this we start to modify the derivatives to push it up.
           max_abs: the maximum average absolute value, per channel, that we
              allow; above this we start to modify the derivatives to push it down.
    """
    def __init__(self, channel_dim: int,
                 min_positive: float = 0.05,
                 max_positive: float = 0.95,
                 max_factor: float = 0.01,
                 min_abs: float = 0.2,
                 max_abs: float = 100.0):
        super(ActivationBalancer, self).__init__()
        self.channel_dim = channel_dim
        self.min_positive = min_positive
        self.max_positive = max_positive
        self.max_factor = max_factor
        self.min_abs = min_abs
        self.max_abs = max_abs

    def forward(self, x: Tensor) -> Tensor:
        return ActivationBalancerFunction.apply(x, self.channel_dim,
                                                self.min_positive, self.max_positive,
                                                self.max_factor, self.min_abs,
                                                self.max_abs)

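# --- Illustrative sketch, not part of the diff (assumes the ActivationBalancer
# class above is in scope): the module is an identity in the forward direction
# and has no parameters; it only rescales gradients in backward, so it can be
# dropped at inference time without changing the output.
def _sketch_balancer_is_identity():
    m = ActivationBalancer(channel_dim=-1)
    x = torch.randn(4, 8, requires_grad=True)
    assert torch.equal(m(x), x)
    assert sum(p.numel() for p in m.parameters()) == 0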
def _double_swish(x: Tensor) -> Tensor:
    # double-swish, implemented/approximated as offset-swish
    return x * torch.sigmoid(x - 1.0)


class DoubleSwishFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        ctx.save_for_backward(x.detach())
        return _double_swish(x)

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        # TODO: can make this more efficient.
        x, = ctx.saved_tensors
        x.requires_grad = True
        with torch.enable_grad():
            y = _double_swish(x)
            y.backward(gradient=y_grad)
            return x.grad


class DoubleSwish(torch.nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        """Return the double-swish activation, an approximation to
        Swish(Swish(x)) that we compute as x * sigmoid(x - 1).
        """
        return DoubleSwishFunction.apply(x)

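# --- Illustrative sketch, not part of the diff: how close the offset-swish
# x * sigmoid(x - 1) is to the exact Swish(Swish(x)) it approximates.
def _sketch_double_swish_approx():
    def swish(t: Tensor) -> Tensor:
        return t * torch.sigmoid(t)

    x = torch.linspace(-5.0, 5.0, steps=201)
    exact = swish(swish(x))
    approx = x * torch.sigmoid(x - 1.0)
    # the gap stays small (well under 0.1 over this range)
    print((exact - approx).abs().max())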
def _test_deriv_balancer_sign():
    channel_dim = 0
    probs = torch.arange(0, 1, 0.01)
    N = 1000
    x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
    x = x.detach()
    x.requires_grad = True
    m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95,
                           max_factor=0.2, min_abs=0.0)

    y_grad = torch.sign(torch.randn(probs.numel(), N))

    y = m(x)
    y.backward(gradient=y_grad)
    print("_test_deriv_balancer_sign: x = ", x)
    print("_test_deriv_balancer_sign: y grad = ", y_grad)
    print("_test_deriv_balancer_sign: x grad = ", x.grad)


def _test_deriv_balancer_magnitude():
    channel_dim = 0
    magnitudes = torch.arange(0, 1, 0.01)
    N = 1000
    x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
    x = x.detach()
    x.requires_grad = True
    m = ActivationBalancer(channel_dim=0,
                           min_positive=0.0, max_positive=1.0,
                           max_factor=0.2,
                           min_abs=0.2, max_abs=0.8)

    y_grad = torch.sign(torch.randn(magnitudes.numel(), N))

    y = m(x)
    y.backward(gradient=y_grad)
    print("_test_deriv_balancer_magnitude: x = ", x)
    print("_test_deriv_balancer_magnitude: y grad = ", y_grad)
    print("_test_deriv_balancer_magnitude: x grad = ", x.grad)


def _test_basic_norm():
    num_channels = 128
    m = BasicNorm(num_channels=num_channels, channel_dim=1)

    x = torch.randn(500, num_channels)

    y = m(x)

    assert y.shape == x.shape
    x_rms = (x**2).mean().sqrt()
    y_rms = (y**2).mean().sqrt()
    print("x rms = ", x_rms)
    print("y rms = ", y_rms)
    assert y_rms < x_rms
    assert y_rms > 0.5 * x_rms


if __name__ == '__main__':
    _test_deriv_balancer_sign()
    _test_deriv_balancer_magnitude()
    _test_basic_norm()

@@ -22,8 +22,6 @@ import logging

from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional

import torch
from lhotse.utils import fix_random_seed

import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest

@@ -1,338 +0,0 @@

import torch
from torch import Tensor
from torch import nn
import math
import random
from typing import Tuple, List


class TensorDiagnosticOptions(object):
    """
    Options object for tensor diagnostics:

    Args:
      memory_limit: the maximum number of bytes we store per tensor (limits
         how many copies of the tensor we cache).
      max_eig_dim: the maximum dimension for which we print out eigenvalues
         (limited for speed reasons).
    """
    def __init__(self,
                 memory_limit: int = (2 ** 20),
                 max_eig_dim: int = 512):
        self.memory_limit = memory_limit
        self.max_eig_dim = max_eig_dim

    def dim_is_summarized(self, size: int):
        return size > 10 and size != 31


def get_tensor_stats(x: Tensor, dim: int,
                     stats_type: str) -> Tuple[Tensor, int]:
    """
    Returns the specified transformation of the Tensor (either x or x.abs()
    or (x > 0)), summed over all but the index `dim`.

    Args:
      x: Tensor, tensor to be analyzed
      dim: dimension with 0 <= dim < x.ndim
      stats_type:
        "abs" -> take abs() before summing
        "positive" -> take (x > 0) before summing
        "rms" -> square before summing, we'll take sqrt later
        "value" -> just sum x itself
        "eigs" -> return the summed outer product over `dim`, for
                  eigenvalue analysis
    Returns (stats, count)
      where stats is a Tensor of shape (x.shape[dim],) (or
      (x.shape[dim], x.shape[dim]) for "eigs"), and the count is an integer
      saying how many items were counted in each element of stats.
    """
    count = x.numel() // x.shape[dim]

    if stats_type == "eigs":
        x = x.transpose(dim, -1)
        x = x.reshape(-1, x.shape[-1])
        # shape of returned tensor: (s, s) where s is size of dimension `dim` of original x.
        return torch.matmul(x.transpose(0, 1), x), count
    elif stats_type == "abs":
        x = x.abs()
    elif stats_type == "rms":
        x = x ** 2
    elif stats_type == "positive":
        x = (x > 0).to(dtype=torch.float)
    else:
        assert stats_type == "value"

    sum_dims = [d for d in range(x.ndim) if d != dim]
    if len(sum_dims) > 0:
        x = torch.sum(x, dim=sum_dims)
    x = x.flatten()
    return x, count

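# Illustrative sketch, not part of the original file (assumes get_tensor_stats
# above is in scope): for stats_type="abs" on a (2, 3) tensor with dim=1, |x|
# is summed over dim 0 and count=2 says how many items went into each element.
def _sketch_get_tensor_stats():
    x = torch.tensor([[1.0, -2.0, 3.0], [-4.0, 5.0, -6.0]])
    stats, count = get_tensor_stats(x, dim=1, stats_type="abs")
    print(stats, count)  # tensor([5., 7., 9.]) 2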
def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
                            options: TensorDiagnosticOptions,
                            sizes_same: bool,
                            stats_type: str):
    """
    This function gets diagnostics for a dimension of a module.

    Args:
      dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
      tensors: the list of cached tensors to get the stats from
      options: options object
      sizes_same: true if all the tensor sizes are the same on this dimension
      stats_type: either "abs" or "positive" or "eigs" or "value"; dictates
         the type of stats we accumulate.  "abs" is mean absolute value,
         "positive" is proportion of positive to nonnegative values, "eigs"
         is eigenvalues after doing outer product on this dim, summed over
         all other dims.
    Returns:
      Diagnostic as a string, either percentiles or the actual values,
      see the code.  Will return the empty string if the diagnostics did
      not make sense to print out for this dimension, e.g. dimension
      mismatch and stats_type == "eigs".
    """
    # stats_and_counts is a list of pairs (Tensor, int)
    stats_and_counts = [get_tensor_stats(x, dim, stats_type) for x in tensors]
    stats = [x[0] for x in stats_and_counts]
    counts = [x[1] for x in stats_and_counts]

    if stats_type == "eigs":
        try:
            stats = torch.stack(stats).sum(dim=0)
        except:
            return ''
        count = sum(counts)
        stats = stats / count
        stats, _ = torch.symeig(stats)
        stats = stats.abs().sqrt()  # sqrt so it reflects data magnitude, like stddev, not variance
    elif sizes_same:
        stats = torch.stack(stats).sum(dim=0)
        count = sum(counts)
        stats = stats / count
    else:
        stats = [x[0] / x[1] for x in stats_and_counts]
        stats = torch.cat(stats, dim=0)
    if stats_type == 'rms':
        stats = stats.sqrt()

    # if `summarize` we print percentiles of the stats; else,
    # we print out individual elements.
    summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
    if summarize:
        # print out percentiles.
        stats = stats.sort()[0]
        num_percentiles = 10
        size = stats.numel()
        percentiles = []
        for i in range(num_percentiles + 1):
            index = (i * (size - 1)) // num_percentiles
            percentiles.append(stats[index].item())
        percentiles = ['%.2g' % x for x in percentiles]
        percentiles = ' '.join(percentiles)
        ans = f'percentiles: [{percentiles}]'
    else:
        ans = stats.tolist()
        ans = ['%.2g' % x for x in ans]
        ans = '[' + ' '.join(ans) + ']'
    if stats_type == "value":
        # This norm is useful because it is strictly less than the largest
        # sqrt(eigenvalue) of the variance, which we print out, and shows,
        # speaking in an approximate way, how much of that largest eigenvalue
        # can be attributed to the mean of the distribution.
        norm = (stats ** 2).sum().sqrt().item()
        mean = stats.mean().item()
        rms = (stats ** 2).mean().sqrt().item()
        ans += f', norm={norm:.2g}, mean={mean:.2g}, rms={rms:.2g}'
    else:
        mean = stats.mean().item()
        rms = (stats ** 2).mean().sqrt().item()
        ans += f', mean={mean:.2g}, rms={rms:.2g}'
    return ans

def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor],
                              options: TensorDiagnosticOptions):
    ndim = tensors[0].ndim
    if ndim > 1:
        stats_types = ["abs", "positive", "value", "rms"]
        if tensors[0].shape[dim] <= options.max_eig_dim:
            stats_types.append("eigs")
    else:
        stats_types = ["value", "abs"]

    for stats_type in stats_types:
        sizes = [x.shape[dim] for x in tensors]
        sizes_same = all([x == sizes[0] for x in sizes])
        s = get_diagnostics_for_dim(dim, tensors,
                                    options, sizes_same,
                                    stats_type)
        if s == '':
            continue

        min_size = min(sizes)
        max_size = max(sizes)
        size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}"
        # stats_type is one of "abs", "positive", "value", "rms" or "eigs".
        print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}")

class TensorDiagnostic(object):
    """
    This class is not directly used by the user; it is responsible for
    collecting diagnostics for a single tensor of a torch.nn.Module
    (an output, a gradient, or a parameter).
    """
    def __init__(self,
                 opts: TensorDiagnosticOptions,
                 name: str):
        self.name = name
        self.opts = opts
        self.saved_tensors = []

    def accumulate(self, x):
        if isinstance(x, tuple):
            x = x[0]
        if not isinstance(x, Tensor):
            return
        if x.device == torch.device('cpu'):
            x = x.detach().clone()
        else:
            x = x.detach().to('cpu', non_blocking=True)
        self.saved_tensors.append(x)
        l = len(self.saved_tensors)
        if l & (l - 1) == 0:  # power of 2..
            self._limit_memory()

    def _limit_memory(self):
        if len(self.saved_tensors) > 1024:
            self.saved_tensors = self.saved_tensors[-1024:]
            return

        tot_mem = 0.0
        for i in reversed(range(len(self.saved_tensors))):
            tot_mem += self.saved_tensors[i].numel() * self.saved_tensors[i].element_size()
            if tot_mem > self.opts.memory_limit:
                self.saved_tensors = self.saved_tensors[i:]
                return

    def print_diagnostics(self):
        if len(self.saved_tensors) == 0:
            print("{name}: no stats".format(name=self.name))
            return
        if self.saved_tensors[0].ndim == 0:
            # ensure there is at least one dim.
            self.saved_tensors = [x.unsqueeze(0) for x in self.saved_tensors]

        try:
            device = torch.device('cuda')
            torch.ones(1, 1, device=device)
        except:
            device = torch.device('cpu')

        ndim = self.saved_tensors[0].ndim
        tensors = [x.to(device) for x in self.saved_tensors]
        for dim in range(ndim):
            print_diagnostics_for_dim(self.name, dim,
                                      tensors,
                                      self.opts)

class ModelDiagnostic(object):
    def __init__(self, opts: TensorDiagnosticOptions = TensorDiagnosticOptions()):
        self.diagnostics = dict()
        self.opts = opts

    def __getitem__(self, name: str):
        if name not in self.diagnostics:
            self.diagnostics[name] = TensorDiagnostic(self.opts, name)
        return self.diagnostics[name]

    def print_diagnostics(self):
        for k in sorted(self.diagnostics.keys()):
            self.diagnostics[k].print_diagnostics()


def attach_diagnostics(model: nn.Module,
                       opts: TensorDiagnosticOptions) -> ModelDiagnostic:
    ans = ModelDiagnostic(opts)
    for name, module in model.named_modules():
        if name == '':
            name = "<top-level>"
        forward_diagnostic = TensorDiagnostic(opts, name + ".output")
        backward_diagnostic = TensorDiagnostic(opts, name + ".grad")

        # Passing _model_diagnostic=ans and _name=name as default arguments,
        # instead of relying on closure capture, ensures we use the current
        # values.  This matters for `name`, since the loop variable gets
        # overwritten: Python closures capture the variable, not the value it
        # had when the hook was defined.
        def forward_hook(_module, _input, _output,
                         _model_diagnostic=ans, _name=name):
            if isinstance(_output, Tensor):
                _model_diagnostic[f"{_name}.output"].accumulate(_output)
            elif isinstance(_output, tuple):
                for i, o in enumerate(_output):
                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o)

        def backward_hook(_module, _input, _output,
                          _model_diagnostic=ans, _name=name):
            if isinstance(_output, Tensor):
                _model_diagnostic[f"{_name}.grad"].accumulate(_output)
            elif isinstance(_output, tuple):
                for i, o in enumerate(_output):
                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o)

        module.register_forward_hook(forward_hook)
        module.register_backward_hook(backward_hook)

    for name, parameter in model.named_parameters():

        def param_backward_hook(grad,
                                _parameter=parameter,
                                _model_diagnostic=ans,
                                _name=name):
            _model_diagnostic[f"{_name}.param_value"].accumulate(_parameter)
            _model_diagnostic[f"{_name}.param_grad"].accumulate(grad)

        parameter.register_hook(param_backward_hook)
    return ans


def _test_tensor_diagnostic():
    opts = TensorDiagnosticOptions(2**20, 512)

    diagnostic = TensorDiagnostic(opts, "foo")

    for _ in range(10):
        diagnostic.accumulate(torch.randn(50, 100) * 10.0)

    diagnostic.print_diagnostics()

    model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80))

    diagnostic = attach_diagnostics(model, opts)
    for _ in range(10):
        T = random.randint(200, 300)
        x = torch.randn(T, 100)
        y = model(x)
        y.sum().backward()

    diagnostic.print_diagnostics()


if __name__ == '__main__':
    _test_tensor_diagnostic()


def _test_func():
    ans = []
    for i in range(10):
        x = list()
        x.append(i)

        def func():
            return x
        ans.append(func)
    return ans
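# Illustrative sketch, not part of the original file: what _test_func() above
# demonstrates -- Python closures capture the variable x, not the value it had
# when `func` was defined, so every returned function sees the final list.
def _sketch_test_func():
    funcs = _test_func()
    print(funcs[0](), funcs[9]())  # both print [9]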