Make ActivationBalancer and MaxEig more efficient.

Daniel Povey 2022-10-12 18:44:52 +08:00
parent 1825336841
commit 12323025d7
2 changed files with 309 additions and 197 deletions

View File

@@ -27,6 +27,7 @@ from encoder_interface import EncoderInterface
 from scaling import (
     ActivationBalancer,
     BasicNorm,
+    MaxEig,
     DoubleSwish,
     ScaledConv1d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
@@ -293,8 +294,11 @@ class ConformerEncoderLayer(nn.Module):
             d_model, channel_dim=-1,
             min_positive=0.45, max_positive=0.55,
             max_abs=6.0,
-            max_var_per_eig=0.2,
         )
+        self.max_eig = MaxEig(
+            d_model, channel_dim=-1,
+        )

     def forward(
         self,
@@ -350,7 +354,7 @@ class ConformerEncoderLayer(nn.Module):
         src = src + self.feed_forward3(src)

-        src = self.norm_final(self.balancer(src))
+        src = self.norm_final(self.max_eig(self.balancer(src)))

         delta = src - src_orig
         bypass_scale = self.bypass_scale
@@ -838,8 +842,9 @@ class RelPositionMultiheadAttention(nn.Module):
         self.in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True)
         self.in_balancer = ActivationBalancer(3 * embed_dim // 2,
-                                              channel_dim=-1, max_abs=5.0,
-                                              max_var_per_eig=0.2)
+                                              channel_dim=-1, max_abs=5.0)
+        self.in_max_eig = MaxEig(3 * embed_dim // 2,
+                                 channel_dim=-1)
         self.proj_balancer = ActivationBalancer(embed_dim // 2,
                                                 channel_dim=-1, max_abs=10.0,
                                                 min_positive=0.0, max_positive=1.0)
@@ -915,7 +920,7 @@ class RelPositionMultiheadAttention(nn.Module):
             before softmax.
         """
         x, weights, scores = self.multi_head_attention_forward(
-            self.in_balancer(self.in_proj(x)),
+            self.in_max_eig(self.in_balancer(self.in_proj(x))),
             pos_emb,
             None if attn_scores_in is None else torch.matmul(attn_scores_in, self.attn_scores_proj_in),
             self.embed_dim,
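A minimal usage sketch, not from the commit itself, of the new wiring: the eigenvalue constraint that used to be ActivationBalancer's max_var_per_eig option now lives in a separate MaxEig module chained after the balancer, as in the norm_final line above. It assumes the updated scaling.py is importable; d_model and the tensor shape are arbitrary.

import torch
from scaling import ActivationBalancer, MaxEig

d_model = 512
# The balancer now only enforces the sign and absolute-value constraints...
balancer = ActivationBalancer(d_model, channel_dim=-1,
                              min_positive=0.45, max_positive=0.55,
                              max_abs=6.0)
# ...and the top-eigenvalue constraint lives in its own module.
max_eig = MaxEig(d_model, channel_dim=-1)

x = torch.randn(20, 4, d_model, requires_grad=True)
y = max_eig(balancer(x))   # both modules are the identity in the forward pass
y.sum().backward()         # they only adjust gradients in the backward pass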

View File

@@ -29,120 +29,35 @@ from torch import Tensor
 from torch.nn import Embedding as ScaledEmbedding

-def _ntuple(n):
-    def parse(x):
-        if isinstance(x, collections.Iterable):
-            return x
-        return tuple(repeat(x, n))
-    return parse
-
-
-_single = _ntuple(1)
-_pair = _ntuple(2)

 class ActivationBalancerFunction(torch.autograd.Function):
     @staticmethod
     def forward(
         ctx,
         x: Tensor,
-        channel_dim: int,
-        min_positive: float,  # e.g. 0.05
-        max_positive: float,  # e.g. 0.95
-        max_factor: float,  # e.g. 0.01
-        min_abs: float,  # e.g. 0.2
-        max_abs: float,  # e.g. 100.0
+        sign_factor: Tensor,
+        scale_factor: Tensor,
+        channel_dim: int,
     ) -> Tensor:
-        if x.requires_grad:
-            if channel_dim < 0:
-                channel_dim += x.ndim
-            sum_dims = [d for d in range(x.ndim) if d != channel_dim]
-            x_normalized = x - torch.mean(x, dim=sum_dims, keepdim=True)
-            xgtmean = (x_normalized > 0)
-            proportion_positive = torch.mean(
-                (x > 0).to(x.dtype), dim=sum_dims, keepdim=True
-            )
-            factor1 = (
-                (min_positive - proportion_positive).relu()
-                * (max_factor / min_positive)
-                if min_positive != 0.0
-                else 0.0
-            )
-            factor2 = (
-                (proportion_positive - max_positive).relu()
-                * (max_factor / (max_positive - 1.0))
-                if max_positive != 1.0
-                else 0.0
-            )
-            # `factor` is a tensor of shape something like (1, 1, num_channels,
-            # 1), containing elements between -1 and 1 that are zero if the
-            # proportion of positive features is between min_positive and
-            # max_positive, max_factor if proportion==0.0 (all features are negative),
-            # and -max_factor if proportion==1.0 (all features are positive).  It is
-            # an amount per channel by which we'll modify the gradient; the sign
-            # of modifying the gradient will depend on the sign of the gradient.
-            factor = factor1 + factor2
-            if isinstance(factor, float):
-                factor = torch.zeros_like(proportion_positive)
-            mean_abs = torch.mean(x_normalized.abs(), dim=sum_dims, keepdim=True)
-            below_threshold = mean_abs < min_abs
-            above_threshold = mean_abs > max_abs
-            ctx.save_for_backward(
-                factor, xgtmean, below_threshold, above_threshold
-            )
-            ctx.max_factor = max_factor
-            ctx.sum_dims = sum_dims
+        if channel_dim < 0:
+            channel_dim += x.ndim
+        ctx.channel_dim = channel_dim
+        xgt0 = (x > 0)
+        ctx.save_for_backward(xgt0, sign_factor, scale_factor)
         return x

     @staticmethod
     def backward(
         ctx, x_grad: Tensor
-    ) -> Tuple[Tensor, None, None, None, None, None, None]:
-        factor, xgtmean, below_threshold, above_threshold = ctx.saved_tensors
-        dtype = x_grad.dtype
-        scale_factor = (
-            (below_threshold.to(dtype) - above_threshold.to(dtype))
-            * (xgtmean.to(dtype) - 0.5)
-            * (ctx.max_factor * 2.0)
-        )
-        neg_delta_grad = x_grad.abs() * (factor + scale_factor)
-        return x_grad - neg_delta_grad, None, None, None, None, None, None
+    ) -> Tuple[Tensor, None, None, None]:
+        xgt0, sign_factor, scale_factor = ctx.saved_tensors
+        for _ in range(ctx.channel_dim, x_grad.ndim - 1):
+            sign_factor = sign_factor.unsqueeze(-1)
+            scale_factor = scale_factor.unsqueeze(-1)
+
+        factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
+        neg_delta_grad = x_grad.abs() * factor
+        return x_grad - neg_delta_grad, None, None, None,

-def find_direction_coeffs(x: Tensor,
-                          prev_direction: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
-    """
-    Figure out (an approximation to) the proportion of the variance of a set of
-    feature vectors that can be attributed to the top eigen-direction.
-    Args:
-     x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
-     prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
-           of the top eigen-direction, or a random direction if this is the first
-           iteration.  Does not have to be normalized, but should be nonzero.
-    Returns: (cur_direction, coeffs), where:
-     cur_direction: a Tensor of shape (num_channels,) that is the current
-           estimate of the top eigen-direction.
-     coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
-           approximately minimizes, (x - coeffs * cur_direction).norm()
-    """
-    (num_frames, num_channels) = x.shape
-    assert num_channels > 1 and num_frames > 1
-    assert prev_direction.shape == (num_channels,)
-    # `coeffs` are the coefficients of `prev_direction` in x.
-    # actually represent the coeffs up to a constant positive factor.
-    coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
-    cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20)
-    return cur_direction, coeffs
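The rewrite above moves all of the statistics out of the autograd Function: forward() now just records the sign of x together with two precomputed per-channel factors, and backward() nudges the gradient by -|g| * (sign_factor + scale_factor * (1[x>0] - 0.5)). A small sketch of that gradient rule on made-up tensors (the numbers are illustrative only, not from the commit):

import torch

# Per-channel factors as ActivationBalancer would precompute them in _update_stats():
# channel 1 is positive too rarely (min_positive), channel 0 is too small (min_abs).
sign_factor = torch.tensor([0.0, 0.01])
scale_factor = torch.tensor([0.02, 0.0])

x = torch.tensor([[0.1, -0.3],
                  [-0.2, 0.4]])
x_grad = torch.tensor([[1.0, -1.0],
                       [0.5, 2.0]])

xgt0 = (x > 0).to(x_grad.dtype)
factor = sign_factor + scale_factor * (xgt0 - 0.5)  # broadcasts over the frame dimension
new_grad = x_grad - x_grad.abs() * factor           # what backward() returns
print(new_grad)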
@@ -152,57 +67,27 @@ class MaxEigLimiterFunction(torch.autograd.Function):
     def forward(
         ctx,
         x: Tensor,
+        coeffs: Tensor,
         direction: Tensor,
         channel_dim: int,
-        subtract_mean: bool,
-        max_variance_proportion: float,
-        grad_scale: float) -> Tuple[Tensor, Tensor]:
-        eps = 1.0e-20
-        num_channels = x.shape[channel_dim]
-        assert max_variance_proportion > 1.0 / num_channels
-        orig_x = x
-        x = x.transpose(channel_dim, -1).reshape(-1, num_channels)
-        if subtract_mean:
-            x = x - x.mean(dim=0)
-        new_direction, coeffs = find_direction_coeffs(x, direction)
-        x_var = (x**2).mean()
-        x_residual = x - coeffs * new_direction
-        x_residual_var = (x_residual**2).mean()
-        # `variance_proportion` is the proportion of the variance accounted for
-        # by the top eigen-direction.
-        variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
-        ans_direction = direction + new_direction  # ensure nonzero even if x == 0
-        ans_direction = ans_direction / ans_direction.norm()
-        if random.random() < 0.0005:
-            logging.info(f"variance_proportion = {variance_proportion.item()}, shape={tuple(x.shape)}")
-        # Caution: this causes a CUDA sync, which is not ideal.
-        if variance_proportion >= max_variance_proportion:
-            ctx.channel_dim = channel_dim
-            ctx.subtract_mean = subtract_mean
-            ctx.grad_scale = grad_scale
-            ctx.save_for_backward(orig_x.detach(),
-                                  coeffs.detach(),
-                                  new_direction.detach())
-        return orig_x, ans_direction
+        grad_scale: float) -> Tensor:
+        ctx.channel_dim = channel_dim
+        ctx.grad_scale = grad_scale
+        ctx.save_for_backward(x.detach(),
+                              coeffs.detach(),
+                              direction.detach())
+        return x

     @staticmethod
     def backward(ctx, x_grad, *args):
-        # the *args is all the other derivs, which should be None or zero.
-        if not hasattr(ctx, 'channel_dim'):
-            # the top eig's proportion of the variance was below the threshold.
-            return x_grad, None, None, None, None, None, None
         with torch.enable_grad():
             (x_orig, coeffs, new_direction) = ctx.saved_tensors
             x_orig.requires_grad = True
             num_channels = x_orig.shape[ctx.channel_dim]
             x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
             new_direction.requires_grad = False
-            if ctx.subtract_mean:
-                x = x - x.mean(dim=0)
+            x = x - x.mean(dim=0)
             x_var = (x ** 2).mean()
             x_residual = x - coeffs * new_direction
             x_residual_var = (x_residual ** 2).mean()
@@ -212,7 +97,7 @@ class MaxEigLimiterFunction(torch.autograd.Function):
             variance_proportion.backward()
         x_orig_grad = x_orig.grad
         x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20)
-        return x_grad + x_extra_grad.detach(), None, None, None, None, None, None
+        return x_grad + x_extra_grad.detach(), None, None, None, None


 class BasicNorm(torch.nn.Module):
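With the activity check moved into MaxEig.forward(), MaxEigLimiterFunction is now an unconditional pass-through: forward() just stashes x, coeffs and the direction, and backward() re-derives d(variance_proportion)/dx under torch.enable_grad() and adds it after rescaling it to have norm grad_scale times that of the incoming gradient. A small sketch of just that rescaling step, with stand-in tensors (not from the commit):

import torch

grad_scale = 0.01
x_grad = torch.randn(100, 128)        # incoming gradient from later layers
penalty_grad = torch.randn(100, 128)  # stands in for d(variance_proportion)/dx

# Mirror the x_extra_grad line in backward() above: give the penalty gradient a
# norm of grad_scale * ||x_grad|| so it stays a small correction.
x_extra_grad = penalty_grad * grad_scale * x_grad.norm() / (penalty_grad.norm() + 1.0e-20)
out_grad = x_grad + x_extra_grad.detach()

print((x_extra_grad.norm() / x_grad.norm()).item())  # approximately grad_scale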
@@ -352,11 +237,15 @@ class ActivationBalancer(torch.nn.Module):
        max_abs: the maximum average-absolute-value difference from the mean
            value per channel, which we allow, before we start to modify
            the derivatives to prevent this.
-       max_var_per_eig: the maximum proportion of the variance of the
-           features/channels, after mean subtraction, that can come from
-           any given eigenvalue.
+       beta: a constant used in decaying stats for the {min,max}_positive and
+           {min,max}_abs constraints.  Likely not critical.
+       prob: determines the probability with which we modify the
+           gradients for the {min,max}_positive and {min,max}_abs constraints,
+           on each forward().  This is done randomly to prevent all layers
+           from doing it at the same time.
+       stats_period: the periodicity with which we update the statistics on
+           the activations.
     """
     def __init__(
         self,
         num_channels: int,
@@ -367,6 +256,9 @@ class ActivationBalancer(torch.nn.Module):
         min_abs: float = 0.2,
         max_abs: float = 100.0,
         max_var_per_eig: float = 0.0,
+        beta: float = 0.75,
+        prob: float = 0.25,
+        stats_period: int = 10,
     ):
         super(ActivationBalancer, self).__init__()
         self.num_channels = num_channels
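A small sketch, with made-up numbers, of what the three new arguments added above control: statistics are refreshed only every stats_period calls, as an exponential moving average with decay beta, and the precomputed gradient factors are applied on roughly a prob fraction of calls, with max_factor divided by prob so the average strength is unchanged.

import torch

beta, prob, stats_period = 0.75, 0.25, 10
max_factor = 0.04

abs_mean = torch.zeros(4)            # running mean of |x| per channel (4 channels here)
for count in range(100):
    if count % stats_period == 0:    # only every stats_period-th forward() pays for stats
        batch_abs_mean = torch.rand(4) + 0.5   # stand-in for mean(|x|) over the batch
        b = beta if count > 0 else 0.0         # the first update just copies the batch stats
        abs_mean = b * abs_mean + (1 - b) * batch_abs_mean

# Gradients are modified on only ~prob of the calls, so the per-call factor is
# boosted to max_factor / prob to keep the expected effect the same.
effective_max_factor = max_factor / prob
print(abs_mean, effective_max_factor)   # e.g. tensor([...]), 0.16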
@@ -376,49 +268,261 @@ class ActivationBalancer(torch.nn.Module):
         self.max_factor = max_factor
         self.min_abs = min_abs
         self.max_abs = max_abs
-        assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
-        self.max_var_per_eig = max_var_per_eig
-        if max_var_per_eig > 0.0:
-            with torch.no_grad():
-                # arbitrary.. would use randn() but want to leave the rest of the model's
-                # random parameters unchanged for comparison
-                direction = torch.arange(num_channels).to(torch.float)
-                direction = direction / direction.norm()
-                self.register_buffer('max_eig_direction', direction)
-        else:
-            self.max_eig_direction = None
+        self.beta = beta
+        self.prob = prob
+        self.stats_period = stats_period
+
+        # count measures how many times the forward() function has been called.
+        self.count = 0
+
+        # the mean of the absolute value of the data per channel
+        self.register_buffer('abs_mean', torch.zeros(num_channels))
+
+        # the proportion of activations that are positive, per channel.
+        self.register_buffer('proportion_positive', torch.zeros(num_channels))
+
+        # `factors` contains two buffers of shape (num_channels,).
+        # `sign_factor` is an expression that will be used to scale the
+        # gradients in backprop; it will be 0 if the max_positive and min_positive
+        # constraints are satisfied.
+        # `scale_factor` is an expression that will be used to encourage the
+        # data to satisfy our min_abs and max_abs constraints; it will be zero if
+        # all constraints are satisfied.
+        self.register_buffer('factors', torch.zeros(2, num_channels))

     def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or not x.requires_grad:
             return x
-        max_eig_prob = 0.25
-        if self.max_var_per_eig > 0 and random.random() < max_eig_prob:
-            with torch.cuda.amp.autocast(enabled=False):
-                x, new_direction = MaxEigLimiterFunction.apply(
-                    x, self.max_eig_direction,
-                    self.channel_dim,
-                    True,  # subtract_mean
-                    self.max_var_per_eig,
-                    self.max_factor / max_eig_prob,
-                )
-                self.max_eig_direction[:] = new_direction.detach()
-
-        balance_prob = 0.25
-        if random.random() < balance_prob:
+
+        count = self.count
+        self.count += 1
+
+        if count % self.stats_period == 0:
+            self._update_stats(x, count)
+
+        if random.random() < self.prob:
+            # The .clone() is in case the forward() gets called multiple times befor
+            factors = self.factors.clone()
+            sign_factor = factors[0]
+            scale_factor = factors[1]
             return ActivationBalancerFunction.apply(
-                x,
-                self.channel_dim,
-                self.min_positive,
-                self.max_positive,
-                self.max_factor / balance_prob,
-                self.min_abs,
-                self.max_abs,
+                x, sign_factor, scale_factor, self.channel_dim,
             )
         else:
             return x
+    def _update_stats(self,
+                      x: Tensor,
+                      count: int):
+        """
+        Updates some statistics that we maintain, describing the average activations per
+        channel.
+        """
+        with torch.no_grad():
+            sum_dims = [d for d in range(x.ndim) if d != self.channel_dim]
+            x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
+            # the random.random() thing is to split the difference if x is zero,
+            # between treating it positive or negative
+            proportion_positive = torch.mean(
+                ((x > 0) if random.random() < 0.5 else (x >= 0)).to(torch.float32), dim=sum_dims,
+            )
+            def filter_inf_nan(y):
+                mask = (y - y != 0)
+                y.masked_fill_(mask, 0.0)
+            filter_inf_nan(x_abs_mean)
+
+            beta = self.beta if count > 0 else 0.0
+            self.abs_mean.mul_(beta).add_(x_abs_mean, alpha=(1-beta))
+            self.proportion_positive.mul_(beta).add_(proportion_positive, alpha=(1-beta))
+
+            max_factor = self.max_factor / self.prob
+            min_positive = self.min_positive
+            max_positive = self.max_positive
+
+            if min_positive == 0.0:
+                factor1 = 0.0
+            else:
+                # 0 if self.proportion_positive >= min_positive, else can be
+                # as large as max_factor.
+                factor1 = ((min_positive - self.proportion_positive).relu() *
+                           (max_factor / min_positive))
+
+            if max_positive == 1.0:
+                factor2 = 0.0
+            else:
+                # 0 if self.proportion_positive <= max_positive, else can be
+                # as large as -max_factor.
+                factor2 = ((self.proportion_positive - max_positive).relu()
+                           * (max_factor / (max_positive - 1.0)))
+
+            sign_factor = self.factors[0]
+            scale_factor = self.factors[1]
+            sign_factor[:] = factor1 + factor2
+
+            # the factor of 2.0 below is just to cancel out a factor of 0.5 that gets introduced when, in
+            # the backprop, we do (xgt0.to(dtype) - 0.5).
+            #
+            # scale_factor_scale, on the other hand, is a heuristically chosen value between 0 and 1,
+            # that we use to make the gradient changes from the 'scale' constraints (min_abs/max_abs)
+            # less strong than those from the sign constraints.
+            #
+            # This is to get rid of a pathology that can happen if, for instance, a
+            # channel is always positive but is too small (max_positive and min_abs constraints both
+            # violated).  If scale_factor_scale were equal to 1.0, then the gradient changes from the
+            # min_positive constraint (trying to make the activation more negative) and from the
+            # min_abs constraint (trying to make the activation more positive) would exactly cancel.
+            # Instead we make the min_positive constraint stronger, so it first makes the value
+            # sometimes negative, and only when that is satisfied, can deal with the absolute-value
+            # constraint.
+            scale_factor_scale = 0.5
+            below_threshold = (self.abs_mean < self.min_abs)
+            above_threshold = (self.abs_mean > self.max_abs)
+            scale_factor[:] = ((below_threshold.to(torch.float32) -
+                                above_threshold.to(torch.float32))
+                               * (max_factor * (2.0 * scale_factor_scale)))
+
+class MaxEig(torch.nn.Module):
+    """
+    Modifies the backpropped derivatives of a function to try to discourage
+    that any given direction in activation space accounts for more than
+    a specified proportion of the covariance (e.g. 0.2).
+
+    Args:
+       num_channels: the number of channels
+       channel_dim: the dimension/axis corresponding to the channel, e.g.
+           -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
+       max_var_per_eig: the maximum proportion of the variance of the
+           features/channels, after mean subtraction, that can come from
+           any given eigenvalue.
+       min_prob: the minimum probability with which we apply this during any invocation
+           of forward(), assuming last time we applied the constraint it was
+           not active; supplied for speed.
+       scale: determines the scale with which we modify the gradients, relative
+           to the existing / unmodified gradients
+    """
+    def __init__(
+        self,
+        num_channels: int,
+        channel_dim: int,
+        max_var_per_eig: float = 0.2,
+        min_prob: float = 0.01,
+        scale: float = 0.01,
+    ):
+        super(MaxEig, self).__init__()
+        self.num_channels = num_channels
+        self.channel_dim = channel_dim
+        self.scale = scale
+        assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
+        self.max_var_per_eig = max_var_per_eig
+
+        # we figure out the dominant direction using the power method: starting with
+        # a random vector, keep multiplying by the covariance and renormalizing.
+        with torch.no_grad():
+            # arbitrary.. would use randn() but want to leave the rest of the model's
+            # random parameters unchanged for comparison
+            direction = torch.arange(num_channels).to(torch.float)
+            direction = direction / direction.norm()
+            self.register_buffer('max_eig_direction', direction)
+
+        self.min_prob = min_prob
+        # cur_prob is the current probability we'll use to apply the ActivationBalancer.
+        # We'll regress this towards min_prob, each time we try to apply it and it is not
+        # active.
+        self.cur_prob = 1.0
+
+    def forward(self, x: Tensor) -> Tensor:
+        if (torch.jit.is_scripting() or
+            self.max_var_per_eig <= 0 or
+            random.random() > self.cur_prob):
+            return x
+
+        with torch.cuda.amp.autocast(enabled=False):
+            eps = 1.0e-20
+            assert x.dtype != torch.float16
+            orig_x = x
+            with torch.no_grad():
+                x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels)
+                x = x - x.mean(dim=0)
+                new_direction, coeffs = self._find_direction_coeffs(x, self.max_eig_direction)
+                x_var = (x**2).mean()
+                x_residual = x - coeffs * new_direction
+                x_residual_var = (x_residual**2).mean()
+
+                # `variance_proportion` is the proportion of the variance accounted for
+                # by the top eigen-direction.
+                variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
+
+                # ensure new direction is nonzero even if x == 0, by including `direction`.
+                self._set_direction(0.1 * self.max_eig_direction + new_direction)
+
+            if random.random() < 0.0005 or __name__ == "__main__":
+                logging.info(f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}")
+
+            if variance_proportion >= self.max_var_per_eig:
+                # The constraint is active.  Note, we should quite rarely
+                # reach here, only near the beginning of training if we are
+                # starting to diverge, should this constraint be active.
+                cur_prob = self.cur_prob
+                self.cur_prob = 1.0  # next time, do the update with probability 1.0.
+                return MaxEigLimiterFunction.apply(orig_x, coeffs, new_direction,
+                                                   self.channel_dim, self.scale)
+            else:
+                # let self.cur_prob exponentially approach self.min_prob, as
+                # long as the constraint is inactive.
+                self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob
+                return orig_x
+
+    def _set_direction(self,
+                       direction: Tensor):
+        """
+        Sets self.max_eig_direction to a normalized version of `direction`
+        """
+        direction = direction.detach()
+        direction = direction / direction.norm()
+        direction_sum = direction.sum().item()
+        if direction_sum - direction_sum == 0:  # no inf/nan
+            self.max_eig_direction[:] = direction
+        else:
+            logging.info(f"Warning: sum of direction in MaxEig is {direction_sum}, "
+                         "num_channels={self.num_channels}, channel_dim={self.channel_dim}")
+
+    def _find_direction_coeffs(self,
+                               x: Tensor,
+                               prev_direction: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+        """
+        Figure out (an approximation to) the proportion of the variance of a set of
+        feature vectors that can be attributed to the top eigen-direction.
+        Args:
+         x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
+         prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
+               of the top eigen-direction, or a random direction if this is the first
+               iteration.  Does not have to be normalized, but should be nonzero.
+        Returns: (cur_direction, coeffs), where:
+         cur_direction: a Tensor of shape (num_channels,) that is the current
+               estimate of the top eigen-direction.
+         coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
+               approximately minimizes, (x - coeffs * cur_direction).norm()
+        """
+        (num_frames, num_channels) = x.shape
+        assert num_channels > 1 and num_frames > 1
+        assert prev_direction.shape == (num_channels,)
+        # `coeffs` are the coefficients of `prev_direction` in x.
+        # actually represent the coeffs up to a constant positive factor.
+        coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
+        cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20)
+        return cur_direction, coeffs


 class DoubleSwishFunction(torch.autograd.Function):
     """
@@ -460,7 +564,8 @@ class DoubleSwish(torch.nn.Module):
         return DoubleSwishFunction.apply(x)


-def _test_max_eig_limiter():
+def _test_max_eig():
     for proportion in [0.1, 0.5, 10.0]:
         logging.info(f"proportion = {proportion}")
@@ -471,15 +576,15 @@ def _test_max_eig_limiter():
         x.requires_grad = True

-        y, new_direction = MaxEigLimiterFunction.apply(x, direction,
-                                                       1, # channel_dim
-                                                       True, # subtract_mean
-                                                       0.5, # max_variance_proportion
-                                                       0.1, # grad_scale
-        )
-        cosine = (new_direction * direction).sum() / (new_direction.norm() * direction.norm())
-        logging.info(f"Direction cosine = {cosine}")
+        num_channels = 128
+        m = MaxEig(num_channels,
+                   1, # channel_dim
+                   0.5, # max_var_per_eig
+                   scale=0.1) # grad_scale
+
+        for _ in range(4):
+            y = m(x)

         y_grad = torch.randn_like(x)
         y.backward(gradient=y_grad)
@@ -494,16 +599,17 @@ def _test_max_eig_limiter():

 def _test_activation_balancer_sign():
     probs = torch.arange(0, 1, 0.01)
     N = 1000
-    x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
+    x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
     x = x.detach()
     x.requires_grad = True
     m = ActivationBalancer(
         probs.numel(),
         channel_dim=0,
         min_positive=0.05,
-        max_positive=0.98,
+        max_positive=0.95,
         max_factor=0.2,
         min_abs=0.0,
+        prob=1.0,
     )

     y_grad = torch.sign(torch.randn(probs.numel(), N))
@@ -531,6 +637,7 @@ def _test_activation_balancer_magnitude():
         max_factor=0.2,
         min_abs=0.2,
         max_abs=0.8,
+        prob=1.0,
     )

     y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
@@ -571,7 +678,7 @@ if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
-    _test_max_eig_limiter()
+    _test_max_eig()
     _test_activation_balancer_sign()
     _test_activation_balancer_magnitude()
     _test_basic_norm()