Reorganize Whiten() code; configs are not the same as before. Also remove MaxEig for self_attn module

Daniel Povey 2022-10-15 23:20:18 +08:00
parent 9919a05612
commit fc728f2738
2 changed files with 215 additions and 143 deletions

@ -31,6 +31,8 @@ from scaling import (
DoubleSwish,
ScaledConv1d,
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
Whiten,
_diag,
)
from torch import Tensor, nn
@ -801,129 +803,6 @@ class RelPositionalEncoding(torch.nn.Module):
return self.dropout(pos_emb)
def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
if x.ndim == 2:
return x.diag()
else:
(batch, dim, dim) = x.shape
x = x.reshape(batch, dim * dim)
x = x[:, ::dim+1]
assert x.shape == (batch, dim)
return x
class WhiteningPenaltyFunction(torch.autograd.Function):
@staticmethod
def forward(ctx,
x: Tensor,
whitening_limit: float,
grad_scale: float) -> Tensor:
ctx.save_for_backward(x)
ctx.whitening_limit = whitening_limit
ctx.grad_scale = grad_scale
return x
@staticmethod
def backward(ctx,
x_grad: Tensor):
x_orig, = ctx.saved_tensors
with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False):
x_detached = x_orig.to(torch.float32).detach()
x_detached.requires_grad = True
assert x_detached.ndim >= 3
x = x_detached.reshape(-1, x_detached.shape[-2],
x_detached.shape[-1]).transpose(0, 1)
(num_groups, num_frames, channels_per_group) = x.shape
# subtract the mean so we use the centered, not uncentered, covariance.
# My experience has been that when we "mess with the gradients" like this,
# it's better not to do anything that tries to move the mean around, because
# that can easily cause instability.
x = x - x.mean(dim=1, keepdim=True)
# x_covar: (num_groups, channels_per_group, channels_per_group)
x_covar = torch.matmul(x.transpose(1, 2), x)
# normalize x_covar so that its average diagonal element is 1.
x_covar = x_covar / (_diag(x_covar).mean() + 1.0e-20)
# x_covar_sq: (num_groups, channels_per_group, channels_per_group).
# if the normalized x_covar were just `num_groups` copies of the
# identity matrix, x_covar_sq would equal x_covar and the metric below
# would be 1.0; in general the metric will be larger than that.
x_covar_sq = torch.matmul(x_covar, x_covar)
metric = _diag(x_covar_sq).mean()
if random.random() < 0.005 or __name__ == "__main__":
logging.info(f"Whitening: num_groups={num_groups}, channels_per_group={channels_per_group}, "
f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}")
(metric - ctx.whitening_limit).relu().backward()
penalty_grad = x_detached.grad
scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() /
(penalty_grad.norm() + 1.0e-20))
penalty_grad = penalty_grad * scale
return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
class Whiten(nn.Module):
def __init__(
self,
whitening_limit: float,
prob: float,
grad_scale: float):
"""
Args:
num_groups: the number of groups to divide the input into before
whitening it. We will attempt to make the feature covariance
within each group, after mean subtraction, as "white" as possible
while having the same trace across all groups.
whitening_limit: a value greater than 1.0, that dictates how much
freedom we have to violate the constraints. 1.0 would mean perfectly
white, with exactly the same trace across groups; larger values
give more freedom. E.g. 2.0.
prob: the probability with which we apply this object (also affects
grad scale). e.g. 0.25
grad_scale: determines the scale on the gradient term from this object,
relative to the rest of the gradient on the attention weights;
will be divided by `prob`. e.g. 0.005
"""
super(Whiten, self).__init__()
assert whitening_limit >= 1
assert 0 < prob <= 1
assert grad_scale >= 0
self.whitening_limit = whitening_limit
self.prob = prob
self.grad_scale = grad_scale
def forward(self,
x: Tensor) -> Tensor:
"""
In the forward pass, this function just returns the input unmodified.
In the backward pass, it will modify the gradients to encourage the
distribution in each group to have a covariance close to (lambda times I)
after mean subtraction, with the same lambda across groups.
For whitening_limit > 1, there will be more freedom to violate this
constraint.
Args:
x: the input of shape (*, num_groups, channels_per_group)
Returns:
x, unmodified. You should make sure
you use the returned value, or the graph will be freed
and nothing will happen in backprop.
"""
if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
return x
else:
return WhiteningPenaltyFunction.apply(x,
self.whitening_limit,
self.grad_scale / self.prob)
class RelPositionMultiheadAttention(nn.Module):
r"""Multi-Head Attention layer with relative position encoding
@ -958,20 +837,20 @@ class RelPositionMultiheadAttention(nn.Module):
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True)
# self.whiten is applied on the values in forward()
self.whiten_values = Whiten(whitening_limit=1.1,
prob=1.0 if __name__ == "__main__" else 0.1,
grad_scale=0.0025)
# self.whiten_values is applied on the values in forward()
self.whiten_values = Whiten(num_groups=num_heads,
whitening_limit=1.1,
prob=(0.025, 0.25),
grad_scale=0.025)
# self.whiten_keys is applied on the keys in forward()
self.whiten_keys = Whiten(whitening_limit=1.1,
prob=1.0 if __name__ == "__main__" else 0.1,
grad_scale=0.0025)
self.whiten_keys = Whiten(num_groups=num_heads,
whitening_limit=1.1,
prob=(0.025, 0.25),
grad_scale=0.025)
self.in_balancer = ActivationBalancer(3 * embed_dim // 2,
channel_dim=-1, max_abs=5.0)
self.in_max_eig = MaxEig(3 * embed_dim // 2,
channel_dim=-1)
self.out_proj = ScaledLinear(
embed_dim // 2, embed_dim, bias=True, initial_scale=0.05
)
@ -980,10 +859,10 @@ class RelPositionMultiheadAttention(nn.Module):
self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True,
initial_scale=0.05)
# self.whiten_values2 is applied on the values in forward2()
self.whiten_values2 = Whiten(whitening_limit=1.1,
prob=1.0 if __name__ == "__main__" else 0.1,
grad_scale=0.0025)
self.whiten_values2 = Whiten(num_groups=num_heads,
whitening_limit=1.1,
prob=(0.025, 0.25),
grad_scale=0.025)
# linear transformation for positional encoding (projects to a scalar per head,
# which will be added to the score).
@ -1037,7 +916,7 @@ class RelPositionMultiheadAttention(nn.Module):
and S is the sequence length.
"""
x, weights = self.multi_head_attention_forward(
self.in_max_eig(self.in_balancer(self.in_proj(x))),
self.in_balancer(self.in_proj(x)),
self.linear_pos(pos_emb),
self.embed_dim,
self.num_heads,
@ -1155,6 +1034,8 @@ class RelPositionMultiheadAttention(nn.Module):
# self-attention
q, k, v = x.chunk(3, dim=-1)
k = self.whiten_keys(k) # does nothing in the forward pass.
v = self.whiten_values(v) # does nothing in the forward pass.
if attn_mask is not None:
assert (
@ -1207,11 +1088,7 @@ class RelPositionMultiheadAttention(nn.Module):
q = (q * scaling).contiguous().view(seq_len, bsz, num_heads, head_dim)
k = k.contiguous().view(-1, bsz, num_heads, head_dim)
k = self.whiten_keys(k) # does nothing in the forward pass.
v = v.contiguous().view(-1, bsz, num_heads, head_dim)
v = self.whiten_values(v) # does nothing in the forward pass.
v = v.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz, "{} == {}".format(
@ -1297,7 +1174,6 @@ class RelPositionMultiheadAttention(nn.Module):
head_dim = embed_dim // (num_heads * 2)
# v: (tgt_len, bsz, embed_dim // 2)
v = self.in_proj2(x)
v = v.contiguous().view(-1, bsz, num_heads, head_dim)
v = self.whiten_values2(v) # does nothing in the forward pass.
v = v.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
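Aside (not part of this diff): the reorganized Whiten no longer infers the grouping from a (*, num_groups, channels_per_group) input shape; it takes num_groups as a constructor argument and accepts a plain (*, num_channels) input. That is why the attention code above can now whiten the keys/values right after chunk() / in_proj2(), before any per-head reshape. A minimal shape sketch with hypothetical sizes, assuming scaling.py is importable:

import torch
from scaling import Whiten   # the class this commit adds to scaling.py

seq_len, bsz, num_heads, head_dim = 10, 2, 4, 8     # hypothetical sizes
k = torch.randn(seq_len, bsz, num_heads * head_dim, requires_grad=True)

# new convention: keep the (*, num_channels) layout and tell Whiten the group count
whiten_keys = Whiten(num_groups=num_heads,
                     whitening_limit=1.1,
                     prob=(0.025, 0.25),
                     grad_scale=0.025)
k = whiten_keys(k)   # identity in the forward pass; gradients may be modified

# old convention (removed by this commit): the caller first reshaped so that the
# last two dims were (num_groups, channels_per_group), e.g.
#   k = k.contiguous().view(-1, bsz, num_heads, head_dim)
# and only then called Whiten, which had no num_groups argument.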

@ -421,6 +421,172 @@ class ActivationBalancer(torch.nn.Module):
return x
def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
if x.ndim == 2:
return x.diag()
else:
(batch, dim, dim) = x.shape
x = x.reshape(batch, dim * dim)
x = x[:, ::dim+1]
assert x.shape == (batch, dim)
return x
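As a quick aside (not part of the commit): the strided slice above, x.reshape(batch, dim * dim)[:, ::dim+1], picks every (dim+1)-th element of each flattened matrix, which is exactly its diagonal. A tiny check against torch.diagonal, assuming the _diag defined above is in scope:

import torch

x = torch.randn(5, 7, 7)                                           # (batch, dim, dim)
assert torch.equal(_diag(x), torch.diagonal(x, dim1=1, dim2=2))    # shape (5, 7)
assert torch.equal(_diag(x[0]), x[0].diag())                       # 2-D input falls back to .diag()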
def _whitening_metric(x: Tensor,
num_groups: int):
"""
Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of
the centered feature covariance are the same within each group's covariance matrix,
and also between groups.
Args:
x: a Tensor of shape (*, num_channels)
num_groups: the number of groups of channels, a number >=1 that divides num_channels
Returns:
Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
greater than 1.0 otherwise.
"""
assert x.dtype != torch.float16
x = x.reshape(-1, x.shape[-1])
(num_frames, num_channels) = x.shape
assert num_channels % num_groups == 0
channels_per_group = num_channels // num_groups
x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
# x now has shape (num_groups, num_frames, channels_per_group)
# subtract the mean so we use the centered, not uncentered, covariance.
# My experience has been that when we "mess with the gradients" like this,
# it's better not to do anything that tries to move the mean around, because
# that can easily cause instability.
x = x - x.mean(dim=1, keepdim=True)
# x_covar: (num_groups, channels_per_group, channels_per_group)
x_covar = torch.matmul(x.transpose(1, 2), x)
x_covar_mean_diag = _diag(x_covar).mean()
# the following expression is what we'd get if we took the matrix product
# of each covariance with itself and measured the mean of its diagonal, i.e.
# the same as _diag(torch.matmul(x_covar, x_covar)).mean().
x_covarsq_mean_diag = (x_covar ** 2).sum() / (num_groups * channels_per_group)
# this metric will be >= 1.0; the larger it is, the less 'white' the data was.
metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20)
return metric
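Aside (not from the commit): for a single group with covariance C of size d x d, the expression above reduces to d * trace(C^2) / trace(C)^2 = d * sum(lambda_i^2) / (sum(lambda_i))^2, which by Cauchy-Schwarz is >= 1.0, with equality exactly when all eigenvalues lambda_i are equal. A small numeric sanity check, assuming the function above:

import torch

torch.manual_seed(0)
d = 64
white = torch.randn(10000, d)                    # roughly isotropic features
print(_whitening_metric(white, 1))               # close to 1.0

direction = torch.randn(d)
collapsed = torch.randn(10000, 1) * direction    # rank-1 covariance: one big eigenvalue
print(_whitening_metric(collapsed, 1))           # close to d, i.e. about 64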
class WhiteningPenaltyFunction(torch.autograd.Function):
@staticmethod
def forward(ctx,
x: Tensor,
num_groups: int,
whitening_limit: float,
grad_scale: float) -> Tensor:
ctx.save_for_backward(x)
ctx.num_groups = num_groups
ctx.whitening_limit = whitening_limit
ctx.grad_scale = grad_scale
return x
@staticmethod
def backward(ctx,
x_grad: Tensor):
x_orig, = ctx.saved_tensors
with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False):
x_detached = x_orig.to(torch.float32).detach()
x_detached.requires_grad = True
metric = _whitening_metric(x_detached, ctx.num_groups)
if random.random() < 0.005 or __name__ == "__main__":
logging.info(f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}")
(metric - ctx.whitening_limit).relu().backward()
penalty_grad = x_detached.grad
scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() /
(penalty_grad.norm() + 1.0e-20))
penalty_grad = penalty_grad * scale
return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
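The backward pass above is an instance of a reusable pattern: identity in the forward, and in the backward a penalty is recomputed on a detached float32 copy, differentiated with a nested autograd call, rescaled so that its norm is grad_scale times the norm of the incoming gradient, and added to it. A stripped-down sketch of that pattern with a generic, hypothetical penalty_fn (not from the commit):

import torch

class GradPenalty(torch.autograd.Function):
    # identity in forward(); backward() adds the gradient of penalty_fn(x),
    # rescaled to grad_scale times the norm of the incoming gradient.
    @staticmethod
    def forward(ctx, x, penalty_fn, grad_scale):
        ctx.save_for_backward(x)
        ctx.penalty_fn = penalty_fn
        ctx.grad_scale = grad_scale
        return x

    @staticmethod
    def backward(ctx, x_grad):
        x_orig, = ctx.saved_tensors
        with torch.enable_grad():
            x_detached = x_orig.detach().to(torch.float32)
            x_detached.requires_grad = True
            ctx.penalty_fn(x_detached).backward()    # penalty_fn must return a scalar
        penalty_grad = x_detached.grad
        scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() /
                                  (penalty_grad.norm() + 1.0e-20))
        # non-tensor inputs (penalty_fn, grad_scale) get None gradients
        return x_grad + (penalty_grad * scale).to(x_grad.dtype), None, None

x = 3.0 * torch.randn(10, 20)
x.requires_grad = True
y = GradPenalty.apply(x, lambda t: (t.pow(2).mean() - 1.0).relu(), 0.01)
y.sum().backward()   # x.grad is all-ones plus a small penalty term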
class Whiten(nn.Module):
def __init__(
self,
num_groups: int,
whitening_limit: float,
prob: Union[float, Tuple[float,float]],
grad_scale: float):
"""
Args:
num_groups: the number of groups to divide the channel dim into before
whitening. We will attempt to make the feature covariance
within each group, after mean subtraction, as "white" as possible,
while having the same trace across all groups.
whitening_limit: a value greater than 1.0, that dictates how much
freedom we have to violate the constraints. 1.0 would mean perfectly
white, with exactly the same trace across groups; larger values
give more freedom. E.g. 2.0.
prob: the probability with which we apply the gradient modification
(also affects the grad scale). May be supplied as a float,
or as a pair (min_prob, max_prob)
grad_scale: determines the scale on the gradient term from this object,
relative to the rest of the gradient on the attention weights.
E.g. 0.02 (you may want to use smaller values than this if prob is large)
"""
super(Whiten, self).__init__()
assert num_groups >= 1
assert whitening_limit >= 1
assert grad_scale >= 0
self.num_groups = num_groups
self.whitening_limit = whitening_limit
if isinstance(prob, float):
assert 0 < prob <= 1
self.prob = prob
else:
(self.min_prob, self.max_prob) = prob
assert 0 < self.min_prob < self.max_prob <= 1
self.prob = self.max_prob
self.grad_scale = grad_scale
def forward(self,
x: Tensor) -> Tensor:
"""
In the forward pass, this function just returns the input unmodified.
In the backward pass, it will modify the gradients to encourage the
distribution in each group to have a covariance close to (lambda times I)
after mean subtraction, with the same lambda across groups.
For whitening_limit > 1, there will be more freedom to violate this
constraint.
Args:
x: the input of shape (*, num_channels)
Returns:
x, unmodified. You should make sure
you use the returned value, or the graph will be freed
and nothing will happen in backprop.
"""
if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
return x
else:
if hasattr(self, 'min_prob') and random.random() < 0.25:
# occasionally switch between min_prob and max_prob, based on whether
# we are above or below the threshold.
if _whitening_metric(x, self.num_groups) > self.whitening_limit:
# there would be a change to the grad.
self.prob = self.max_prob
else:
self.prob = self.min_prob
return WhiteningPenaltyFunction.apply(x,
self.num_groups,
self.whitening_limit,
self.grad_scale)
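A usage sketch (not part of the commit), assuming the Whiten class above and mirroring the configuration used in the attention module of the other changed file. With prob given as a pair, the penalty is applied with probability max_prob while the measured metric is above whitening_limit, and only with probability min_prob once the activations are close enough to white:

import torch

whiten = Whiten(num_groups=8,              # e.g. one group per attention head
                whitening_limit=1.1,
                prob=(0.025, 0.25),        # (min_prob, max_prob)
                grad_scale=0.025)

x = torch.randn(100, 4, 512, requires_grad=True)   # (frames, batch, num_channels)
y = whiten(x)            # forward pass: y is simply x
y.sum().backward()       # backward pass: x.grad may include the whitening penalty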
class MaxEig(torch.nn.Module):
"""
Modifies the backpropped derivatives of a function to try to discourage
@ -632,6 +798,35 @@ def _test_max_eig():
assert not torch.allclose(x.grad, y_grad)
def _test_whiten():
for proportion in [0.1, 0.5, 10.0]:
logging.info(f"_test_whiten(): proportion = {proportion}")
x = torch.randn(100, 128)
direction = torch.randn(128)
coeffs = torch.randn(100, 1)
x += proportion * direction * coeffs
x.requires_grad = True
num_channels = 128
m = Whiten(1, # num_groups
5.0, # whitening_limit,
prob=1.0,
grad_scale=0.1) # grad_scale
for _ in range(4):
y = m(x)
y_grad = torch.randn_like(x)
y.backward(gradient=y_grad)
if proportion < 0.2:
assert torch.allclose(x.grad, y_grad)
elif proportion > 1.0:
assert not torch.allclose(x.grad, y_grad)
def _test_activation_balancer_sign():
probs = torch.arange(0, 1, 0.01)
@ -714,6 +909,7 @@ if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
_test_whiten()
_test_max_eig()
_test_activation_balancer_sign()
_test_activation_balancer_magnitude()