Add max_var_per_eig in self-attn

Daniel Povey 2022-09-18 21:22:01 +08:00
parent 76031a7c1d
commit 0f567e27a5
2 changed files with 252 additions and 9 deletions

View File

@ -173,7 +173,8 @@ class ConformerEncoderLayer(nn.Module):
self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
ActivationBalancer(channel_dim=-1, max_abs=10.0),
ActivationBalancer(dim_feedforward,
channel_dim=-1, max_abs=10.0),
DoubleSwish(),
nn.Dropout(dropout),
ScaledLinear(dim_feedforward, d_model,
@ -182,7 +183,8 @@ class ConformerEncoderLayer(nn.Module):
self.feed_forward_macaron = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
ActivationBalancer(channel_dim=-1, max_abs=10.0),
ActivationBalancer(dim_feedforward,
channel_dim=-1, max_abs=10.0),
DoubleSwish(),
nn.Dropout(dropout),
ScaledLinear(dim_feedforward, d_model,
@ -196,7 +198,7 @@ class ConformerEncoderLayer(nn.Module):
# try to ensure the output is close to zero-mean (or at least, zero-median).
self.balancer = ActivationBalancer(
channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
d_model, channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
)
self.dropout = nn.Dropout(dropout)
@ -464,8 +466,11 @@ class RelPositionMultiheadAttention(nn.Module):
), "embed_dim must be divisible by num_heads"
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
self.in_balancer = ActivationBalancer(channel_dim=-1, max_abs=5.0)
self.proj_balancer = ActivationBalancer(channel_dim=-1, max_abs=10.0,
self.in_balancer = ActivationBalancer(3 * embed_dim,
channel_dim=-1, max_abs=5.0,
max_var_per_eig=0.1)
self.proj_balancer = ActivationBalancer(embed_dim,
channel_dim=-1, max_abs=10.0,
min_positive=0.0, max_positive=1.0)
self.out_proj = ScaledLinear(
embed_dim, embed_dim, bias=True, initial_scale=0.5
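For context on the commit title: the in_balancer is sized for the concatenated query/key/value projection, so the new max_var_per_eig=0.1 constraint applies jointly across all three. A minimal editorial sketch, assuming the balancer is applied to the output of in_proj (the embed_dim value and shapes are illustrative, not taken from this diff):

import torch
import torch.nn as nn

embed_dim = 512  # illustrative
in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
in_balancer = ActivationBalancer(3 * embed_dim, channel_dim=-1,
                                 max_abs=5.0, max_var_per_eig=0.1)
x = torch.randn(100, 4, embed_dim)  # (time, batch, embed_dim)
qkv = in_balancer(in_proj(x))       # constraint acts on the 3 * embed_dim channels
q, k, v = qkv.chunk(3, dim=-1)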
@ -901,6 +906,7 @@ class ConvolutionModule(nn.Module):
# it will be in a better position to start learning something, i.e. to latch onto
# the correct range.
self.deriv_balancer1 = ActivationBalancer(
2 * channels,
channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
)
@ -915,7 +921,7 @@ class ConvolutionModule(nn.Module):
)
self.deriv_balancer2 = ActivationBalancer(
channel_dim=1, min_positive=0.05, max_positive=1.0
channels, channel_dim=1, min_positive=0.05, max_positive=1.0
)
self.activation = DoubleSwish()
@ -1001,7 +1007,8 @@ class Conv2dSubsampling(nn.Module):
kernel_size=3,
padding=1,
),
ActivationBalancer(channel_dim=1),
ActivationBalancer(layer1_channels,
channel_dim=1),
DoubleSwish(),
nn.Conv2d(
in_channels=layer1_channels,
@ -1009,7 +1016,8 @@ class Conv2dSubsampling(nn.Module):
kernel_size=3,
stride=2,
),
ActivationBalancer(channel_dim=1),
ActivationBalancer(layer2_channels,
channel_dim=1),
DoubleSwish(),
nn.Conv2d(
in_channels=layer2_channels,
@ -1017,7 +1025,8 @@ class Conv2dSubsampling(nn.Module):
kernel_size=3,
stride=2,
),
ActivationBalancer(channel_dim=1),
ActivationBalancer(layer3_channels,
channel_dim=1),
DoubleSwish(),
)
out_height = (((in_channels - 1) // 2 - 1) // 2)
@ -1028,6 +1037,7 @@ class Conv2dSubsampling(nn.Module):
self.out_norm = BasicNorm(out_channels, learn_eps=False)
# constrain median of output to be close to zero.
self.out_balancer = ActivationBalancer(
out_channels,
channel_dim=-1, min_positive=0.45, max_positive=0.55
)

View File

@ -114,6 +114,173 @@ class ActivationBalancerFunction(torch.autograd.Function):
return x_grad - neg_delta_grad, None, None, None, None, None, None
def find_direction_coeffs(x: Tensor,
prev_direction: Tensor) -> Tuple[Tensor, Tensor]:
"""
Do one iteration of a power-iteration-like update: given a set of feature
vectors and a previous estimate of the top eigen-direction, return a refined
estimate of that direction together with the per-frame coefficients of
`prev_direction` in x.
Args:
x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
of the top eigen-direction, or a random direction if this is the first
iteration. Does not have to be normalized, but should be nonzero.
Returns: (cur_direction, coeffs), where:
cur_direction: a Tensor of shape (num_channels,) that is the current
estimate of the top eigen-direction.
coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
approximately minimizes, (x - coeffs * cur_direction).norm()
"""
(num_frames, num_channels) = x.shape
assert num_channels > 1 and num_frames > 1
assert prev_direction.shape == (num_channels,)
# `coeffs` are the coefficients of `prev_direction` in x; since `prev_direction`
# need not be normalized, they are only correct up to a constant positive factor.
coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20)
return cur_direction, coeffs
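For intuition (editorial note, not part of this diff): cur_direction is proportional to X^T X applied to prev_direction, so calling find_direction_coeffs repeatedly is essentially power iteration and converges toward the dominant eigen-direction. A small self-contained sketch, assuming the function above is in scope:

import torch

def _demo_power_iteration(num_iters: int = 10) -> None:
    torch.manual_seed(0)
    # Plant a dominant direction in otherwise isotropic data.
    true_dir = torch.randn(64)
    true_dir = true_dir / true_dir.norm()
    x = torch.randn(500, 64) + 5.0 * torch.randn(500, 1) * true_dir
    direction = torch.randn(64)  # arbitrary starting estimate
    for _ in range(num_iters):
        direction, _ = find_direction_coeffs(x, direction)
    cosine = (direction * true_dir).sum() / (direction.norm() * true_dir.norm())
    print(f"|cosine| with planted direction: {cosine.abs():.3f}")  # approaches 1.0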
def get_max_eig_proportion(x: Tensor,
prev_direction: Tensor,
subtract_mean: bool) -> Tuple[Tensor, Tensor]:
"""
Figure out (an approximation to) the proportion of the variance of a set of
feature vectors that can be attributed to the top eigen-direction.
Args:
x: a Tensor of shape (*, num_channels). There must be more than one frame,
i.e. x.numel() // num_channels > 1.
prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
of the top eigen-direction, or a random direction if this is the first
iteration. Expected to be without gradient. Does not have to be
normalized.
subtract_mean: if True, we will first subtract the mean of x, over the
frames. We suggest setting this to True in most circumstances.
Returns: (cur_direction, proportion), where:
cur_direction: a Tensor of shape (num_channels,) that is the current
estimate of the top eigen-direction. Detached / not intended to be
differentiable.
proportion: a scalar Tensor containing the proportion of the variance
of the input that is in direction `cur_direction`. This is with
gradient, that can be propagated back to x.
"""
num_channels = x.shape[-1]
assert prev_direction.shape == (num_channels,)
x = x.reshape(-1, num_channels)
if subtract_mean:
x = x - x.mean(dim=0)
with torch.no_grad():
cur_norm = prev_direction.norm()
prev_direction = prev_direction / cur_norm
is_ok = (cur_norm / cur_norm == 1.0)
# if there was a problem like NaN or inf, restart. this should be very rare.
prev_direction = torch.where(is_ok.unsqueeze(-1).expand(prev_direction.shape),
prev_direction,
torch.randn_like(prev_direction) * (num_channels ** -0.5))
# `coeffs` are the coefficients of `prev_direction` in x.
coeffs = (x * prev_direction).sum(dim=1, keepdim=True)
x_norm = x.norm()
x_coeffs1_norm = (x - coeffs * prev_direction).norm()
with torch.no_grad():
cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20)
x_coeffs2_norm = (x - coeffs * cur_direction).norm()
# For the returned direction, interpolate with prev_direction so that
# even if x == 0 we get a nonzero new direction.
ans_direction = 0.5 * (prev_direction + cur_direction)
x_sumsq = (x**2).sum() + 1.0e-20
x_remaining_sumsq = ((x - coeffs * cur_direction) ** 2).sum() + 1.0e-20
proportion = (x_sumsq - x_remaining_sumsq) / x_sumsq
return (ans_direction, proportion)
print(f"x_norm={x_norm}, x_coeffs1_norm={x_coeffs1_norm}, x_coeffs2_norm={x_coeffs2_norm}")
class MaxEigLimiterFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x: Tensor,
direction: Tensor,
channel_dim: int,
prob: float,
subtract_mean: bool,
max_variance_proportion: float,
grad_scale: float) -> Tuple[Tensor, Tensor]:
if random.random() > prob:
return x, direction
eps = 1.0e-20
num_channels = x.shape[channel_dim]
assert max_variance_proportion > 1.0 / num_channels
orig_x = x
x = x.transpose(channel_dim, -1).reshape(-1, num_channels)
if subtract_mean:
x = x - x.mean(dim=0)
new_direction, coeffs = find_direction_coeffs(x, direction)
x_var = (x**2).sum()
x_residual = x - coeffs * new_direction
x_residual_var = (x_residual**2).sum()
# `variance_proportion` is the proportion of the variance accounted for
# by the top eigen-direction.
variance_proportion = (x_var - x_residual_var) / x_var
ans_direction = direction + new_direction # ensure nonzero even if x == 0
ans_direction = ans_direction / ans_direction.norm()
logging.info(f"variance_proportion = {variance_proportion.item()}")
# Caution: this causes a CUDA sync, which is not ideal.
if variance_proportion >= max_variance_proportion:
ctx.channel_dim = channel_dim
ctx.subtract_mean = subtract_mean
ctx.grad_scale = grad_scale
ctx.save_for_backward(orig_x.detach(),
coeffs.detach(),
new_direction.detach())
return orig_x, ans_direction
@staticmethod
def backward(ctx, x_grad, *args):
# *args holds the gradients for the remaining outputs, which should be None or zero.
if not hasattr(ctx, 'channel_dim'):
# the top eig's proportion of the variance was below the threshold.
return x_grad, None, None, None, None, None, None
with torch.enable_grad():
(x_orig, coeffs, new_direction) = ctx.saved_tensors
x_orig.requires_grad = True
num_channels = x_orig.shape[ctx.channel_dim]
x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
new_direction.requires_grad = False
if ctx.subtract_mean:
x = x - x.mean(dim=0)
x_var = (x**2).sum()
x_residual = x - coeffs * new_direction
x_residual_var = (x_residual**2).sum()
# `variance_proportion` is the proportion of the variance accounted for
# by the top eigen-direction. This is to be minimized.
variance_proportion = (x_var - x_residual_var) / x_var
variance_proportion.backward()
x_orig_grad = x_orig.grad
x_extra_grad = x_orig_grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20)
return x_grad + x_extra_grad, None, None, None, None, None, None
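In words: the forward pass leaves x unchanged, and only when the top eigen-direction accounts for more than max_variance_proportion of the variance does the backward pass add the gradient of variance_proportion, rescaled to have norm roughly grad_scale times that of the incoming gradient. A rough editorial sketch of that extra term (illustrative names; like backward() it treats the saved coeffs and direction as constants, and assumes subtract_mean=True):

import torch

def _sketch_extra_grad(x: torch.Tensor, y_grad: torch.Tensor,
                       coeffs: torch.Tensor, direction: torch.Tensor,
                       grad_scale: float = 0.1) -> torch.Tensor:
    # x: (num_frames, num_channels); coeffs: (num_frames, 1); direction: (num_channels,).
    x = x.detach().requires_grad_(True)
    coeffs, direction = coeffs.detach(), direction.detach()
    x0 = x - x.mean(dim=0)
    x_var = (x0 ** 2).sum()
    x_residual_var = ((x0 - coeffs * direction) ** 2).sum()
    variance_proportion = (x_var - x_residual_var) / x_var
    variance_proportion.backward()
    # Scale the penalty gradient relative to the main gradient, so the constraint
    # nudges x without overwhelming the task loss.
    extra = x.grad * grad_scale * y_grad.norm() / (x.grad.norm() + 1.0e-20)
    return y_grad + extra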
class BasicNorm(torch.nn.Module):
@ -236,6 +403,7 @@ class ActivationBalancer(torch.nn.Module):
Args:
num_channels: the number of channels
channel_dim: the dimension/axis corresponding to the channel, e.g.
-1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
min_positive: the minimum, per channel, of the proportion of the time
@ -252,29 +420,56 @@ class ActivationBalancer(torch.nn.Module):
max_abs: the maximum average-absolute-value difference from the mean
value per channel, which we allow, before we start to modify
the derivatives to prevent this.
max_var_per_eig: the maximum proportion of the variance of the
features/channels, after mean subtraction, that is allowed to come
from any single eigen-direction of the feature covariance.
"""
def __init__(
self,
num_channels: int,
channel_dim: int,
min_positive: float = 0.05,
max_positive: float = 0.95,
max_factor: float = 0.01,
min_abs: float = 0.2,
max_abs: float = 100.0,
max_var_per_eig: float = 0.0,
):
super(ActivationBalancer, self).__init__()
self.num_channels = num_channels
self.channel_dim = channel_dim
self.min_positive = min_positive
self.max_positive = max_positive
self.max_factor = max_factor
self.min_abs = min_abs
self.max_abs = max_abs
assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
self.max_var_per_eig = max_var_per_eig
if max_var_per_eig > 0.0:
with torch.no_grad():
direction = torch.randn(num_channels)
direction = direction / direction.norm()
self.register_buffer('max_eig_direction', direction)
else:
self.max_eig_direction = None
def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting():
return x
if self.max_var_per_eig > 0:
x, new_direction = MaxEigLimiterFunction.apply(
x, self.max_eig_direction,
self.channel_dim,
0.1, # prob
True, # subtract_mean
self.max_var_per_eig,
self.max_factor,
)
self.max_eig_direction[:] = new_direction
return ActivationBalancerFunction.apply(
x,
self.channel_dim,
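Putting the pieces together, a hedged editorial sketch of how the updated module is used (values are illustrative; both constraints act by modifying gradients, so the forward pass returns x unchanged):

import torch

balancer = ActivationBalancer(512, channel_dim=-1,
                              max_abs=5.0, max_var_per_eig=0.1)
x = torch.randn(20, 8, 512, requires_grad=True)  # (seq_len, batch, num_channels)
y = balancer(x)       # forward pass returns x unchanged
y.sum().backward()    # any constraints show up as modifications to x.grad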
@ -326,6 +521,35 @@ class DoubleSwish(torch.nn.Module):
return DoubleSwishFunction.apply(x)
def _test_max_eig_limiter():
for proportion in [0.1, 0.5, 10.0]:
logging.info(f"proportion = {proportion}")
x = torch.randn(100, 128)
direction = torch.randn(128)
coeffs = torch.randn(100, 1)
x += proportion * direction * coeffs
x.requires_grad = True
y, new_direction = MaxEigLimiterFunction.apply(x, direction,
1, # channel_dim
1.0, # prob
True, # subtract_mean
0.5, # max_variance_proportion
0.1, # grad_scale
)
cosine = (new_direction * direction).sum() / (new_direction.norm() * direction.norm())
logging.info(f"Direction cosine = {cosine}")
y_grad = torch.randn_like(x)
y.backward(gradient=y_grad)
if proportion < 0.2:
assert torch.allclose(x.grad, y_grad)
elif proportion > 1.0:
assert not torch.allclose(x.grad, y_grad)
@ -336,6 +560,7 @@ def _test_activation_balancer_sign():
x = x.detach()
x.requires_grad = True
m = ActivationBalancer(
probs.numel(),
channel_dim=0,
min_positive=0.05,
max_positive=0.98,
@ -361,6 +586,7 @@ def _test_activation_balancer_magnitude():
x = x.detach()
x.requires_grad = True
m = ActivationBalancer(
magnitudes.numel(),
channel_dim=0,
min_positive=0.0,
max_positive=1.0,
@ -402,10 +628,17 @@ def _test_double_swish_deriv():
torch.autograd.gradcheck(m, x)
def _test_get_max_eig_proportion():
x = torch.randn(100, 128)
d = torch.randn(128) * (128 ** -0.5)
get_max_eig_proportion(x, d, True)
if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
_test_max_eig_limiter()
_test_get_max_eig_proportion()
_test_activation_balancer_sign()
_test_activation_balancer_magnitude()
_test_basic_norm()