mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Make ActivationBalancer and MaxEig more efficient.
This commit is contained in:
parent
1825336841
commit
12323025d7
@ -27,6 +27,7 @@ from encoder_interface import EncoderInterface
|
||||
from scaling import (
|
||||
ActivationBalancer,
|
||||
BasicNorm,
|
||||
MaxEig,
|
||||
DoubleSwish,
|
||||
ScaledConv1d,
|
||||
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
|
||||
@ -293,8 +294,11 @@ class ConformerEncoderLayer(nn.Module):
|
||||
d_model, channel_dim=-1,
|
||||
min_positive=0.45, max_positive=0.55,
|
||||
max_abs=6.0,
|
||||
max_var_per_eig=0.2,
|
||||
)
|
||||
self.max_eig = MaxEig(
|
||||
d_model, channel_dim=-1,
|
||||
)
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -350,7 +354,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
|
||||
src = src + self.feed_forward3(src)
|
||||
|
||||
src = self.norm_final(self.balancer(src))
|
||||
src = self.norm_final(self.max_eig(self.balancer(src)))
|
||||
|
||||
delta = src - src_orig
|
||||
bypass_scale = self.bypass_scale
|
||||
@ -838,8 +842,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
|
||||
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True)
|
||||
self.in_balancer = ActivationBalancer(3 * embed_dim // 2,
|
||||
channel_dim=-1, max_abs=5.0,
|
||||
max_var_per_eig=0.2)
|
||||
channel_dim=-1, max_abs=5.0)
|
||||
self.in_max_eig = MaxEig(3 * embed_dim // 2,
|
||||
channel_dim=-1)
|
||||
self.proj_balancer = ActivationBalancer(embed_dim // 2,
|
||||
channel_dim=-1, max_abs=10.0,
|
||||
min_positive=0.0, max_positive=1.0)
|
||||
@ -915,7 +920,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
before softmax.
|
||||
"""
|
||||
x, weights, scores = self.multi_head_attention_forward(
|
||||
self.in_balancer(self.in_proj(x)),
|
||||
self.in_max_eig(self.in_balancer(self.in_proj(x))),
|
||||
pos_emb,
|
||||
None if attn_scores_in is None else torch.matmul(attn_scores_in, self.attn_scores_proj_in),
|
||||
self.embed_dim,
|
||||
|
||||
@ -29,120 +29,35 @@ from torch import Tensor
|
||||
from torch.nn import Embedding as ScaledEmbedding
|
||||
|
||||
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
if isinstance(x, collections.Iterable):
|
||||
return x
|
||||
return tuple(repeat(x, n))
|
||||
|
||||
return parse
|
||||
|
||||
|
||||
_single = _ntuple(1)
|
||||
_pair = _ntuple(2)
|
||||
|
||||
|
||||
class ActivationBalancerFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
x: Tensor,
|
||||
channel_dim: int,
|
||||
min_positive: float, # e.g. 0.05
|
||||
max_positive: float, # e.g. 0.95
|
||||
max_factor: float, # e.g. 0.01
|
||||
min_abs: float, # e.g. 0.2
|
||||
max_abs: float, # e.g. 100.0
|
||||
ctx,
|
||||
x: Tensor,
|
||||
sign_factor: Tensor,
|
||||
scale_factor: Tensor,
|
||||
channel_dim: int,
|
||||
) -> Tensor:
|
||||
if x.requires_grad:
|
||||
if channel_dim < 0:
|
||||
channel_dim += x.ndim
|
||||
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
|
||||
x_normalized = x - torch.mean(x, dim=sum_dims, keepdim=True)
|
||||
xgtmean = (x_normalized > 0)
|
||||
proportion_positive = torch.mean(
|
||||
(x > 0).to(x.dtype), dim=sum_dims, keepdim=True
|
||||
)
|
||||
factor1 = (
|
||||
(min_positive - proportion_positive).relu()
|
||||
* (max_factor / min_positive)
|
||||
if min_positive != 0.0
|
||||
else 0.0
|
||||
)
|
||||
factor2 = (
|
||||
(proportion_positive - max_positive).relu()
|
||||
* (max_factor / (max_positive - 1.0))
|
||||
if max_positive != 1.0
|
||||
else 0.0
|
||||
)
|
||||
# `factor` is a tensor of shape something like (1, 1, num_channels,
|
||||
# 1), containing elements between -1 and 1 that are zero if the
|
||||
# proportion of positive features is between min_positive and
|
||||
# max_positive, max_factor if proportion==0.0 (all features are negative),
|
||||
# and -max_factor if proportion==1.0 (all features are positive). It is
|
||||
# an amount per channel by which we'll modify the gradient; the sign
|
||||
# of modifying the gradient will depend on the sign of the gradient.
|
||||
|
||||
factor = factor1 + factor2
|
||||
if isinstance(factor, float):
|
||||
factor = torch.zeros_like(proportion_positive)
|
||||
|
||||
mean_abs = torch.mean(x_normalized.abs(), dim=sum_dims, keepdim=True)
|
||||
below_threshold = mean_abs < min_abs
|
||||
above_threshold = mean_abs > max_abs
|
||||
|
||||
ctx.save_for_backward(
|
||||
factor, xgtmean, below_threshold, above_threshold
|
||||
)
|
||||
ctx.max_factor = max_factor
|
||||
ctx.sum_dims = sum_dims
|
||||
if channel_dim < 0:
|
||||
channel_dim += x.ndim
|
||||
ctx.channel_dim = channel_dim
|
||||
xgt0 = (x > 0)
|
||||
ctx.save_for_backward(xgt0, sign_factor, scale_factor)
|
||||
return x
|
||||
|
||||
|
||||
@staticmethod
|
||||
def backward(
|
||||
ctx, x_grad: Tensor
|
||||
) -> Tuple[Tensor, None, None, None, None, None, None]:
|
||||
factor, xgtmean, below_threshold, above_threshold = ctx.saved_tensors
|
||||
dtype = x_grad.dtype
|
||||
scale_factor = (
|
||||
(below_threshold.to(dtype) - above_threshold.to(dtype))
|
||||
* (xgtmean.to(dtype) - 0.5)
|
||||
* (ctx.max_factor * 2.0)
|
||||
)
|
||||
) -> Tuple[Tensor, None, None, None]:
|
||||
xgt0, sign_factor, scale_factor = ctx.saved_tensors
|
||||
for _ in range(ctx.channel_dim, x_grad.ndim - 1):
|
||||
sign_factor = sign_factor.unsqueeze(-1)
|
||||
scale_factor = scale_factor.unsqueeze(-1)
|
||||
|
||||
neg_delta_grad = x_grad.abs() * (factor + scale_factor)
|
||||
return x_grad - neg_delta_grad, None, None, None, None, None, None
|
||||
|
||||
|
||||
def find_direction_coeffs(x: Tensor,
|
||||
prev_direction: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
"""
|
||||
Figure out (an approximation to) the proportion of the variance of a set of
|
||||
feature vectors that can be attributed to the top eigen-direction.
|
||||
Args:
|
||||
x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
|
||||
prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
|
||||
of the top eigen-direction, or a random direction if this is the first
|
||||
iteration. Does not have to be normalized, but should be nonzero.
|
||||
|
||||
Returns: (cur_direction, coeffs), where:
|
||||
cur_direction: a Tensor of shape (num_channels,) that is the current
|
||||
estimate of the top eigen-direction.
|
||||
coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
|
||||
approximately minimizes, (x - coeffs * cur_direction).norm()
|
||||
"""
|
||||
(num_frames, num_channels) = x.shape
|
||||
assert num_channels > 1 and num_frames > 1
|
||||
|
||||
assert prev_direction.shape == (num_channels,)
|
||||
|
||||
# `coeffs` are the coefficients of `prev_direction` in x.
|
||||
# actually represent the coeffs up to a constant positive factor.
|
||||
coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
|
||||
|
||||
cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20)
|
||||
|
||||
return cur_direction, coeffs
|
||||
factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
|
||||
neg_delta_grad = x_grad.abs() * factor
|
||||
return x_grad - neg_delta_grad, None, None, None,
|
||||
|
||||
|
||||
|
||||
@ -152,57 +67,27 @@ class MaxEigLimiterFunction(torch.autograd.Function):
|
||||
def forward(
|
||||
ctx,
|
||||
x: Tensor,
|
||||
coeffs: Tensor,
|
||||
direction: Tensor,
|
||||
channel_dim: int,
|
||||
subtract_mean: bool,
|
||||
max_variance_proportion: float,
|
||||
grad_scale: float) -> Tuple[Tensor, Tensor]:
|
||||
eps = 1.0e-20
|
||||
num_channels = x.shape[channel_dim]
|
||||
assert max_variance_proportion > 1.0 / num_channels
|
||||
orig_x = x
|
||||
x = x.transpose(channel_dim, -1).reshape(-1, num_channels)
|
||||
if subtract_mean:
|
||||
x = x - x.mean(dim=0)
|
||||
new_direction, coeffs = find_direction_coeffs(x, direction)
|
||||
x_var = (x**2).mean()
|
||||
x_residual = x - coeffs * new_direction
|
||||
x_residual_var = (x_residual**2).mean()
|
||||
# `variance_proportion` is the proportion of the variance accounted for
|
||||
# by the top eigen-direction.
|
||||
variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
|
||||
grad_scale: float) -> Tensor:
|
||||
ctx.channel_dim = channel_dim
|
||||
ctx.grad_scale = grad_scale
|
||||
ctx.save_for_backward(x.detach(),
|
||||
coeffs.detach(),
|
||||
direction.detach())
|
||||
return x
|
||||
|
||||
ans_direction = direction + new_direction # ensure nonzero even if x == 0
|
||||
ans_direction = ans_direction / ans_direction.norm()
|
||||
|
||||
if random.random() < 0.0005:
|
||||
logging.info(f"variance_proportion = {variance_proportion.item()}, shape={tuple(x.shape)}")
|
||||
|
||||
# Caution: this causes a CUDA sync, which is not ideal.
|
||||
if variance_proportion >= max_variance_proportion:
|
||||
ctx.channel_dim = channel_dim
|
||||
ctx.subtract_mean = subtract_mean
|
||||
ctx.grad_scale = grad_scale
|
||||
ctx.save_for_backward(orig_x.detach(),
|
||||
coeffs.detach(),
|
||||
new_direction.detach())
|
||||
|
||||
return orig_x, ans_direction
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, x_grad, *args):
|
||||
# the *args is all the other derivs, which should be None or zero.
|
||||
if not hasattr(ctx, 'channel_dim'):
|
||||
# the top eig's proportion of the variance was below the threshold.
|
||||
return x_grad, None, None, None, None, None, None
|
||||
with torch.enable_grad():
|
||||
(x_orig, coeffs, new_direction) = ctx.saved_tensors
|
||||
x_orig.requires_grad = True
|
||||
num_channels = x_orig.shape[ctx.channel_dim]
|
||||
x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
|
||||
new_direction.requires_grad = False
|
||||
if ctx.subtract_mean:
|
||||
x = x - x.mean(dim=0)
|
||||
x = x - x.mean(dim=0)
|
||||
x_var = (x ** 2).mean()
|
||||
x_residual = x - coeffs * new_direction
|
||||
x_residual_var = (x_residual ** 2).mean()
|
||||
@ -212,7 +97,7 @@ class MaxEigLimiterFunction(torch.autograd.Function):
|
||||
variance_proportion.backward()
|
||||
x_orig_grad = x_orig.grad
|
||||
x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20)
|
||||
return x_grad + x_extra_grad.detach(), None, None, None, None, None, None
|
||||
return x_grad + x_extra_grad.detach(), None, None, None, None
|
||||
|
||||
|
||||
class BasicNorm(torch.nn.Module):
|
||||
@ -352,11 +237,15 @@ class ActivationBalancer(torch.nn.Module):
|
||||
max_abs: the maximum average-absolute-value difference from the mean
|
||||
value per channel, which we allow, before we start to modify
|
||||
the derivatives to prevent this.
|
||||
max_var_per_eig: the maximum proportion of the variance of the
|
||||
features/channels, after mean subtraction, that can come from
|
||||
any given eigenvalue.
|
||||
beta: a constant used in decaying stats for the {min,max}_positive and
|
||||
{min,max}_abs constraints. Likely not critical.
|
||||
prob: determines the probability with which we modify the
|
||||
gradients for the {min,max}_positive and {min,max}_abs constraints,
|
||||
on each forward(). This is done randomly to prevent all layers
|
||||
from doing it at the same time.
|
||||
stats_period: the periodicity with which we update the statistics on
|
||||
the activations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_channels: int,
|
||||
@ -367,6 +256,9 @@ class ActivationBalancer(torch.nn.Module):
|
||||
min_abs: float = 0.2,
|
||||
max_abs: float = 100.0,
|
||||
max_var_per_eig: float = 0.0,
|
||||
beta: float = 0.75,
|
||||
prob: float = 0.25,
|
||||
stats_period: int = 10,
|
||||
):
|
||||
super(ActivationBalancer, self).__init__()
|
||||
self.num_channels = num_channels
|
||||
@ -376,49 +268,261 @@ class ActivationBalancer(torch.nn.Module):
|
||||
self.max_factor = max_factor
|
||||
self.min_abs = min_abs
|
||||
self.max_abs = max_abs
|
||||
assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
|
||||
self.max_var_per_eig = max_var_per_eig
|
||||
if max_var_per_eig > 0.0:
|
||||
with torch.no_grad():
|
||||
# arbitrary.. would use randn() but want to leave the rest of the model's
|
||||
# random parameters unchanged for comparison
|
||||
direction = torch.arange(num_channels).to(torch.float)
|
||||
direction = direction / direction.norm()
|
||||
self.register_buffer('max_eig_direction', direction)
|
||||
else:
|
||||
self.max_eig_direction = None
|
||||
self.beta = beta
|
||||
self.prob = prob
|
||||
self.stats_period = stats_period
|
||||
|
||||
# count measures how many times the forward() function has been called.
|
||||
self.count = 0
|
||||
|
||||
# the mean of the absolute value of the data per channel
|
||||
self.register_buffer('abs_mean', torch.zeros(num_channels))
|
||||
|
||||
# the proportion of activations that are positive, per channel.
|
||||
self.register_buffer('proportion_positive', torch.zeros(num_channels))
|
||||
|
||||
# `factors` contains two buffers of shape (num_channels,).
|
||||
# `sign_factor` is an expression that will be used to scale the
|
||||
# gradients in backprop; it will be 0 if the max_positive and min_positive
|
||||
# contstraints are satisfied.
|
||||
# `scale_factor` is an expression that will be used to encourage the
|
||||
# data to satisfy our min_abs and max_abs constraints; it will be zero if
|
||||
# all constraints are satisfied.
|
||||
self.register_buffer('factors', torch.zeros(2, num_channels))
|
||||
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or not x.requires_grad:
|
||||
return x
|
||||
|
||||
max_eig_prob = 0.25
|
||||
if self.max_var_per_eig > 0 and random.random() < max_eig_prob:
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
x, new_direction = MaxEigLimiterFunction.apply(
|
||||
x, self.max_eig_direction,
|
||||
self.channel_dim,
|
||||
True, # subtract_mean
|
||||
self.max_var_per_eig,
|
||||
self.max_factor / max_eig_prob,
|
||||
)
|
||||
self.max_eig_direction[:] = new_direction.detach()
|
||||
count = self.count
|
||||
self.count += 1
|
||||
|
||||
balance_prob = 0.25
|
||||
if random.random() < balance_prob:
|
||||
if count % self.stats_period == 0:
|
||||
self._update_stats(x, count)
|
||||
|
||||
if random.random() < self.prob:
|
||||
# The .clone() is in case the forward() gets called multiple times befor
|
||||
factors = self.factors.clone()
|
||||
sign_factor = factors[0]
|
||||
scale_factor = factors[1]
|
||||
return ActivationBalancerFunction.apply(
|
||||
x,
|
||||
self.channel_dim,
|
||||
self.min_positive,
|
||||
self.max_positive,
|
||||
self.max_factor / balance_prob,
|
||||
self.min_abs,
|
||||
self.max_abs,
|
||||
x, sign_factor, scale_factor, self.channel_dim,
|
||||
)
|
||||
else:
|
||||
return x
|
||||
|
||||
def _update_stats(self,
|
||||
x: Tensor,
|
||||
count: int):
|
||||
"""
|
||||
Updates some statistics that we maintain, describing the average activations per
|
||||
channel.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
sum_dims = [d for d in range(x.ndim) if d != self.channel_dim]
|
||||
|
||||
x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
|
||||
# the random.random() thing is to split the difference if x is zero,
|
||||
# between treating it positive or negative
|
||||
proportion_positive = torch.mean(
|
||||
((x > 0) if random.random() < 0.5 else (x >= 0)).to(torch.float32), dim = sum_dims,
|
||||
)
|
||||
|
||||
def filter_inf_nan(y):
|
||||
mask = (y - y != 0)
|
||||
y.masked_fill_(mask, 0.0)
|
||||
|
||||
filter_inf_nan(x_abs_mean)
|
||||
|
||||
beta = self.beta if count > 0 else 0.0
|
||||
self.abs_mean.mul_(beta).add_(x_abs_mean, alpha=(1-beta))
|
||||
self.proportion_positive.mul_(beta).add_(proportion_positive, alpha=(1-beta))
|
||||
|
||||
max_factor = self.max_factor / self.prob
|
||||
min_positive = self.min_positive
|
||||
max_positive = self.max_positive
|
||||
|
||||
if min_positive == 0.0:
|
||||
factor1 = 0.0
|
||||
else:
|
||||
# 0 if self.proportion_positive >= min_positive, else can be
|
||||
# as large as max_factor.
|
||||
factor1 = ((min_positive - self.proportion_positive).relu() *
|
||||
(max_factor / min_positive))
|
||||
if max_positive == 1.0:
|
||||
factor2 = 0.0
|
||||
else:
|
||||
# 0 if self.proportion_positive <= max_positive, else can be
|
||||
# as large as -max_factor.
|
||||
factor2 = ((self.proportion_positive - max_positive).relu()
|
||||
* (max_factor / (max_positive - 1.0)))
|
||||
sign_factor = self.factors[0]
|
||||
scale_factor = self.factors[1]
|
||||
sign_factor[:] = factor1 + factor2
|
||||
|
||||
# the factor of 2.0 below is just to cancel out a factor of 0.5 that gets introduced when, in
|
||||
# the backprop, we do (xgt0.to(dtype) - 0.5).
|
||||
#
|
||||
# scale_factor_scale, on the other hand, is a heuristically chosen value between 0 and 1,
|
||||
# that we use to make the gradient changes from the 'scale' constraints (min_abs/max_abs)
|
||||
# less strong than those from the sign constraints.
|
||||
#
|
||||
# This is to get rid of a pathology that can happen if, for instance, a
|
||||
# channel is always positive but is too small (max_positive and min_abs constraints both
|
||||
# violated). If scale_factor_scale were equal to 1.0, then the gradient changes from the
|
||||
# min_positive constraint (trying to make the activation more negative) and from the
|
||||
# min_abs constraint (trying to make the activation more positive) would exactly cancel.
|
||||
# Instead we make the min_positive constraint stronger, so it first makes the value
|
||||
# sometimes negative, and only when that is satisfied, can deal with the absolute-value
|
||||
# constraint.
|
||||
scale_factor_scale = 0.5
|
||||
below_threshold = (self.abs_mean < self.min_abs)
|
||||
above_threshold = (self.abs_mean > self.max_abs)
|
||||
scale_factor[:] = ((below_threshold.to(torch.float32) -
|
||||
above_threshold.to(torch.float32))
|
||||
* (max_factor * (2.0 * scale_factor_scale)))
|
||||
|
||||
|
||||
class MaxEig(torch.nn.Module):
|
||||
"""
|
||||
Modifies the backpropped derivatives of a function to try to discourage
|
||||
that any given direction in activation space accounts for more than
|
||||
a specified proportion of the covariance (e.g. 0.2).
|
||||
|
||||
|
||||
Args:
|
||||
num_channels: the number of channels
|
||||
channel_dim: the dimension/axis corresponding to the channel, e.g.
|
||||
-1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
|
||||
max_var_per_eig: the maximum proportion of the variance of the
|
||||
features/channels, after mean subtraction, that can come from
|
||||
any given eigenvalue.
|
||||
min_prob: the minimum probability with which we apply this during any invocation
|
||||
of forward(), assuming last time we applied the constraint it was
|
||||
not active; supplied for speed.
|
||||
scale: determines the scale with which we modify the gradients, relative
|
||||
to the existing / unmodified gradients
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
num_channels: int,
|
||||
channel_dim: int,
|
||||
max_var_per_eig: float = 0.2,
|
||||
min_prob: float = 0.01,
|
||||
scale: float = 0.01,
|
||||
):
|
||||
super(MaxEig, self).__init__()
|
||||
self.num_channels = num_channels
|
||||
self.channel_dim = channel_dim
|
||||
self.scale = scale
|
||||
assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
|
||||
self.max_var_per_eig = max_var_per_eig
|
||||
|
||||
# we figure out the dominant direction using the power method: starting with
|
||||
# a random vector, keep multiplying by the covariance and renormalizing.
|
||||
with torch.no_grad():
|
||||
# arbitrary.. would use randn() but want to leave the rest of the model's
|
||||
# random parameters unchanged for comparison
|
||||
direction = torch.arange(num_channels).to(torch.float)
|
||||
direction = direction / direction.norm()
|
||||
self.register_buffer('max_eig_direction', direction)
|
||||
|
||||
self.min_prob = min_prob
|
||||
# cur_prob is the current probability we'll use to apply the ActivationBalancer.
|
||||
# We'll regress this towards prob, each tiem we try to apply it and it is not
|
||||
# active.
|
||||
self.cur_prob = 1.0
|
||||
|
||||
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if (torch.jit.is_scripting() or
|
||||
self.max_var_per_eig <= 0 or
|
||||
random.random() > self.cur_prob):
|
||||
return x
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
eps = 1.0e-20
|
||||
assert x.dtype != torch.float16
|
||||
orig_x = x
|
||||
with torch.no_grad():
|
||||
x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels)
|
||||
x = x - x.mean(dim=0)
|
||||
new_direction, coeffs = self._find_direction_coeffs(x, self.max_eig_direction)
|
||||
x_var = (x**2).mean()
|
||||
x_residual = x - coeffs * new_direction
|
||||
x_residual_var = (x_residual**2).mean()
|
||||
|
||||
# `variance_proportion` is the proportion of the variance accounted for
|
||||
# by the top eigen-direction.
|
||||
variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
|
||||
|
||||
# ensure new direction is nonzero even if x == 0, by including `direction`.
|
||||
self._set_direction(0.1 * self.max_eig_direction + new_direction)
|
||||
|
||||
if random.random() < 0.0005 or __name__ == "__main__":
|
||||
logging.info(f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}")
|
||||
|
||||
if variance_proportion >= self.max_var_per_eig:
|
||||
# The constraint is active. Note, we should quite rarely
|
||||
# reach here, only near the beginning of training if we are
|
||||
# starting to diverge, should this constraint be active.
|
||||
cur_prob = self.cur_prob
|
||||
self.cur_prob = 1.0 # next time, do the update with probability 1.0.
|
||||
return MaxEigLimiterFunction.apply(orig_x, coeffs, new_direction,
|
||||
self.channel_dim, self.scale)
|
||||
else:
|
||||
# let self.cur_prob exponentially approach self.min_prob, as
|
||||
# long as the constraint is inactive.
|
||||
self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob
|
||||
return orig_x
|
||||
|
||||
|
||||
def _set_direction(self,
|
||||
direction: Tensor):
|
||||
"""
|
||||
Sets self.max_eig_direction to a normalized version of `direction`
|
||||
"""
|
||||
direction = direction.detach()
|
||||
direction = direction / direction.norm()
|
||||
direction_sum = direction.sum().item()
|
||||
if direction_sum - direction_sum == 0: # no inf/nan
|
||||
self.max_eig_direction[:] = direction
|
||||
else:
|
||||
logging.info(f"Warning: sum of direction in MaxEig is {direction_sum}, "
|
||||
"num_channels={self.num_channels}, channel_dim={self.channel_dim}")
|
||||
|
||||
|
||||
def _find_direction_coeffs(self,
|
||||
x: Tensor,
|
||||
prev_direction: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
"""
|
||||
Figure out (an approximation to) the proportion of the variance of a set of
|
||||
feature vectors that can be attributed to the top eigen-direction.
|
||||
Args:
|
||||
x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
|
||||
prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
|
||||
of the top eigen-direction, or a random direction if this is the first
|
||||
iteration. Does not have to be normalized, but should be nonzero.
|
||||
|
||||
Returns: (cur_direction, coeffs), where:
|
||||
cur_direction: a Tensor of shape (num_channels,) that is the current
|
||||
estimate of the top eigen-direction.
|
||||
coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
|
||||
approximately minimizes, (x - coeffs * cur_direction).norm()
|
||||
"""
|
||||
(num_frames, num_channels) = x.shape
|
||||
assert num_channels > 1 and num_frames > 1
|
||||
assert prev_direction.shape == (num_channels,)
|
||||
# `coeffs` are the coefficients of `prev_direction` in x.
|
||||
# actually represent the coeffs up to a constant positive factor.
|
||||
coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
|
||||
cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20)
|
||||
return cur_direction, coeffs
|
||||
|
||||
|
||||
|
||||
|
||||
class DoubleSwishFunction(torch.autograd.Function):
|
||||
"""
|
||||
@ -460,7 +564,8 @@ class DoubleSwish(torch.nn.Module):
|
||||
return DoubleSwishFunction.apply(x)
|
||||
|
||||
|
||||
def _test_max_eig_limiter():
|
||||
|
||||
def _test_max_eig():
|
||||
|
||||
for proportion in [0.1, 0.5, 10.0]:
|
||||
logging.info(f"proportion = {proportion}")
|
||||
@ -471,15 +576,15 @@ def _test_max_eig_limiter():
|
||||
|
||||
x.requires_grad = True
|
||||
|
||||
y, new_direction = MaxEigLimiterFunction.apply(x, direction,
|
||||
1, # channel_dim
|
||||
True, # subtract_mean
|
||||
0.5, # max_variance_proportion
|
||||
0.1, # grad_scale
|
||||
)
|
||||
num_channels = 128
|
||||
m = MaxEig(num_channels,
|
||||
1, # channel_dim
|
||||
0.5, # max_var_per_eig
|
||||
scale=0.1) # grad_scale
|
||||
|
||||
cosine = (new_direction * direction).sum() / (new_direction.norm() * direction.norm())
|
||||
logging.info(f"Direction cosine = {cosine}")
|
||||
|
||||
for _ in range(4):
|
||||
y = m(x)
|
||||
|
||||
y_grad = torch.randn_like(x)
|
||||
y.backward(gradient=y_grad)
|
||||
@ -494,16 +599,17 @@ def _test_max_eig_limiter():
|
||||
def _test_activation_balancer_sign():
|
||||
probs = torch.arange(0, 1, 0.01)
|
||||
N = 1000
|
||||
x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
|
||||
x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
|
||||
x = x.detach()
|
||||
x.requires_grad = True
|
||||
m = ActivationBalancer(
|
||||
probs.numel(),
|
||||
channel_dim=0,
|
||||
min_positive=0.05,
|
||||
max_positive=0.98,
|
||||
max_positive=0.95,
|
||||
max_factor=0.2,
|
||||
min_abs=0.0,
|
||||
prob=1.0,
|
||||
)
|
||||
|
||||
y_grad = torch.sign(torch.randn(probs.numel(), N))
|
||||
@ -531,6 +637,7 @@ def _test_activation_balancer_magnitude():
|
||||
max_factor=0.2,
|
||||
min_abs=0.2,
|
||||
max_abs=0.8,
|
||||
prob=1.0,
|
||||
)
|
||||
|
||||
y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
|
||||
@ -571,7 +678,7 @@ if __name__ == "__main__":
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
_test_max_eig_limiter()
|
||||
_test_max_eig()
|
||||
_test_activation_balancer_sign()
|
||||
_test_activation_balancer_magnitude()
|
||||
_test_basic_norm()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user