Implement whitening of values in conformer.

Daniel Povey 2022-10-15 15:27:05 +08:00
parent 125e1b167c
commit 91840faa97


@@ -801,124 +801,126 @@ class RelPositionalEncoding(torch.nn.Module):
         return self.dropout(pos_emb)


-class EntropyPenaltyFunction(torch.autograd.Function):
+def _diag(x: Tensor):  # like .diag(), but works for tensors with 3 dims.
+    if x.ndim == 2:
+        return x.diag()
+    else:
+        (batch, dim, dim) = x.shape
+        x = x.reshape(batch, dim * dim)
+        x = x[:, ::dim+1]
+        assert x.shape == (batch, dim)
+        return x
+
+
+class WhiteningPenaltyFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx,
-                attn_weights: Tensor,
-                num_heads: int,
-                entropy_limit: float,
+                x: Tensor,
+                whitening_limit: float,
                 grad_scale: float) -> Tensor:
-        ctx.save_for_backward(attn_weights)
-        ctx.num_heads = num_heads
-        ctx.entropy_limit = entropy_limit
+        ctx.save_for_backward(x)
+        ctx.whitening_limit = whitening_limit
         ctx.grad_scale = grad_scale
-        return attn_weights
+        return x

     @staticmethod
     def backward(ctx,
-                 attn_weights_grad: Tensor):
-        attn_weights, = ctx.saved_tensors
-        num_heads = ctx.num_heads
-        entropy_limit = ctx.entropy_limit
-        grad_scale = ctx.grad_scale
+                 x_grad: Tensor):
+        x_orig, = ctx.saved_tensors
        with torch.enable_grad():
            with torch.cuda.amp.autocast(enabled=False):
-                attn_weights_orig = attn_weights.to(torch.float32).detach()
-                attn_weights_orig.requires_grad = True
-                bsz = attn_weights_orig.shape[0] // num_heads
-                seq_len = attn_weights_orig.shape[2]
-                attn_weights = attn_weights_orig.reshape(bsz, num_heads,
-                                                         seq_len, seq_len)
+                x_detached = x_orig.to(torch.float32).detach()
+                x_detached.requires_grad = True
+                assert x_detached.ndim >= 3
+                x = x_detached.reshape(-1, x_detached.shape[-2],
+                                       x_detached.shape[-1]).transpose(0, 1)
+                (num_groups, num_frames, channels_per_group) = x.shape
-                grad_norms = attn_weights_grad.detach().reshape(
-                    bsz, num_heads, seq_len * seq_len).norm(dim=(0,2))
+                # subtract the mean so we use the centered, not uncentered, covariance.
+                # My experience has been that when we "mess with the gradients" like this,
+                # it's better not to do anything that tries to move the mean around, because
+                # that can easily cause instability.
+                x = x - x.mean(dim=1, keepdim=True)
+                # x_covar: (num_groups, channels_per_group, channels_per_group)
+                x_covar = torch.matmul(x.transpose(1, 2), x)
+                # normalize x_covar so that its average diagonal element is 1.
+                x_covar = x_covar / (_diag(x_covar).mean() + 1.0e-20)
+                # x_covar_sq: (num_groups, channels_per_group, channels_per_group).
+                # if the normalized x_covar were just `num_groups` copies of the
+                # identity matrix, x_covar_sq would have the same value.  But
+                # in general, it will be larger than that.
+                x_covar_sq = torch.matmul(x_covar, x_covar)
+                metric = _diag(x_covar_sq).mean()
-                entropy = ((attn_weights + 1.0e-20).log() * attn_weights).sum(dim=-1)
-                # entropy: (bsz, num_heads, seq_len)
-                entropy = -entropy.mean(dim=(0,2))
-                # entropy: (num_heads,)
-                assert entropy.shape == (num_heads,)
-                excess_entropy = (entropy - entropy_limit).relu()
-                above_cutoff = (entropy > 0)  # tensor of shape (num_heads,)
-                small_grad_norm = (grad_norms < grad_norms.mean())
-                will_penalize = torch.logical_and(above_cutoff, small_grad_norm)
                 if random.random() < 0.005 or __name__ == "__main__":
-                    logging.info(f"entropy = {entropy}, entropy_limit={entropy_limit}, above_cutoff={above_cutoff}, small_grad_norm={small_grad_norm}, will_penalize={will_penalize}")
-                will_penalize_sum = will_penalize.to(torch.float32).sum().item()
-                if will_penalize_sum == 0:
-                    # grad would be 0.  I'm guessing that checking this, and
-                    # incurring a CUDA sync, may save time relative to doing the
-                    # backprop of the entropy, but I'm not sure.
-                    return attn_weights_grad, None, None, None
-                # Treat `excess_entropy` as a loss, to be minimized.
-                excess_entropy.backward(gradient=will_penalize.to(torch.float32))
-                entropy_grad = attn_weights_orig.grad
-                scale = ((grad_scale * will_penalize_sum / num_heads) *
-                         (attn_weights_grad.to(torch.float32).norm() /
-                          (entropy_grad.norm() + 1.0e-20)))
-                entropy_grad = entropy_grad * scale
-                return attn_weights_grad + entropy_grad.to(attn_weights_grad.dtype), None, None, None
+                    logging.info(f"Whitening: num_groups={num_groups}, channels_per_group={channels_per_group}, "
+                                 f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}")
+                (metric - ctx.whitening_limit).relu().backward()
+                penalty_grad = x_detached.grad
+                scale = ctx.grad_scale * (x.to(torch.float32).norm() /
+                                          (penalty_grad.norm() + 1.0e-20))
+                penalty_grad = penalty_grad * scale
+        return x_grad + penalty_grad.to(x_grad.dtype), None, None, None


-class EntropyPenalty(nn.Module):
+class Whiten(nn.Module):
     def __init__(
             self,
-            num_heads: float,
-            entropy_delta: float,
+            whitening_limit: float,
             prob: float,
             grad_scale: float):
         """
         Args:
-          num_heads: the number of attention heads in the self-attention module that
-              this is attached to.
-          entropy_delta: the delta from the maximum entropy, that we aim to
-              decrease the entropy to if it is above.  So the maximum entropy
-              should be max(log(seq_len) - entropy_cutoff, 0.5 * log(seq_len));
-              the second term is to make sure the limit never becomes tiny or
-              negative in the case of short sequences.
-          prob: the probability with which we apply this object.
+          num_groups: the number of groups to divide the input into before
+              whitening it.  We will attempt to make the feature covariance
+              within each group, after mean subtraction, as "white" as possible
+              while having the same trace across all groups.
+          whitening_limit: a value greater than 1.0, that dictates how much
+              freedom we have to violate the constraints.  1.0 would mean perfectly
+              white, with exactly the same trace across groups; larger values
+              give more freedom.  E.g. 2.0.
+          prob: the probability with which we apply this object (also affects
+              grad scale).  e.g. 0.25
          grad_scale: determines the scale on the gradient term from this object,
              relative to the rest of the gradient on the attention weights;
-              will be divided by `prob`.
+              will be divided by `prob`.  e.g. 0.005
         """
-        super(EntropyPenalty, self).__init__()
-        self.num_heads = num_heads
-        self.entropy_delta = entropy_delta
+        super(Whiten, self).__init__()
+        assert whitening_limit >= 1
+        assert 0 < prob <= 1
+        assert grad_scale >= 0
+        self.whitening_limit = whitening_limit
         self.prob = prob
         self.grad_scale = grad_scale

     def forward(self,
-                attn_weights: Tensor) -> Tensor:
+                x: Tensor) -> Tensor:
         """
-        In the forward pass, this function just returns the attention weights.
+        In the forward pass, this function just returns the input unmodified.
         In the backward pass, it will modify the gradients to ensure that the
-        entropy of the attention heads is not too large.  (We have noticed
-        that too-large/almost-maximal entropy in the attention distribution
-        is associated with heads that are not doing anything useful.
+        distribution in each group has close to (lambda times I) as the covariance
+        after mean subtraction, with the same lambda across groups.
+        For whitening_limit > 1, there will be more freedom to violate this
+        constraint.

         Args:
-           attn_weights: the attention weights, after the log, with shape
-              (batch_size * num_heads, seq_len, seq_len), satisfying:
-                 attn_weights.sum(dim=-1) == 1.
+           x: the input of shape (*, num_groups, channels_per_group)

         Returns:
-            the attn_weights, without any change.  You should make sure
-            you use the returned attention weights, or the graph will be freed
-            and nothing will happen in backprop.
+            x, unmodified.  You should make sure
+            you use the returned value, or the graph will be freed
+            and nothing will happen in backprop.
         """
-        if not attn_weights.requires_grad or random.random() > self.prob:
-            return attn_weights
+        if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
+            return x
         else:
-            seq_len = attn_weights.shape[2]
-            max_entropy = math.log(seq_len)
-            entropy_limit = max(max_entropy - self.entropy_delta,
-                                0.5 * max_entropy)
-            return EntropyPenaltyFunction.apply(attn_weights,
-                                                self.num_heads,
-                                                entropy_limit,
-                                                self.grad_scale / self.prob)
+            return WhiteningPenaltyFunction.apply(x,
+                                                  self.whitening_limit,
+                                                  self.grad_scale / self.prob)
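
For intuition, here is a minimal standalone sketch (not from the commit; the function name and test tensors are illustrative) of the whitening metric that the new backward pass penalizes: the mean diagonal of the squared, trace-normalized covariance, which is 1.0 for perfectly white features with equal scale across groups, and grows as the per-group covariances move away from a shared multiple of the identity.

    import torch

    def whitening_metric_sketch(x: torch.Tensor) -> torch.Tensor:
        # x: (num_groups, num_frames, channels_per_group), as in the backward pass above.
        x = x - x.mean(dim=1, keepdim=True)              # centered covariance
        x_covar = torch.matmul(x.transpose(1, 2), x)     # (num_groups, c, c)
        diag = x_covar.diagonal(dim1=-2, dim2=-1)        # plays the role of _diag()
        x_covar = x_covar / (diag.mean() + 1.0e-20)      # average diagonal element -> 1
        x_covar_sq = torch.matmul(x_covar, x_covar)
        return x_covar_sq.diagonal(dim1=-2, dim2=-1).mean()

    white = torch.randn(4, 1000, 64)                     # near-white features: metric close to 1.0
    mixed = torch.matmul(white, torch.randn(4, 64, 64))  # correlated channels: metric well above 1.0
    print(whitening_metric_sketch(white), whitening_metric_sketch(mixed))

Values of this metric above whitening_limit (2.0 at the call sites below) are what trigger the penalty gradient.
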
@@ -955,6 +957,13 @@ class RelPositionMultiheadAttention(nn.Module):
         ), "embed_dim//2 must be divisible by num_heads"

         self.in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True)
+
+        # self.whiten is applied on the values in forward()
+        self.whiten = Whiten(whitening_limit=2.0,
+                             prob=1.0 if __name__ == "__main__" else 0.1,
+                             grad_scale=0.0025)
+
         self.in_balancer = ActivationBalancer(3 * embed_dim // 2,
                                               channel_dim=-1, max_abs=5.0)
         self.in_max_eig = MaxEig(3 * embed_dim // 2,
@@ -966,17 +975,15 @@ class RelPositionMultiheadAttention(nn.Module):
         self.in_proj2 = nn.Linear(embed_dim, embed_dim // 2, bias=False)
         self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True,
                                       initial_scale=0.05)
+        # self.whiten is applied on the values in forward2()
+        self.whiten2 = Whiten(whitening_limit=2.0,
+                              prob=1.0 if __name__ == "__main__" else 0.1,
+                              grad_scale=0.0025)

         # linear transformation for positional encoding (projects to a scalar per head,
         # which will be added to the score).
         self.linear_pos = ScaledLinear(embed_dim, num_heads, initial_scale=0.05)
-
-        self.entropy_penalty = EntropyPenalty(num_heads,
-                                              entropy_delta=1.5,
-                                              prob=1.0 if __name__ == "__main__" else 0.2,
-                                              grad_scale=0.01)

         self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
         self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads))

         # linear transformation for positional encoding.
         self.linear_pos = nn.Linear(embed_dim, num_heads, bias=False)
@@ -1196,7 +1203,9 @@ class RelPositionMultiheadAttention(nn.Module):
         q = (q * scaling).contiguous().view(seq_len, bsz, num_heads, head_dim)
         k = k.contiguous().view(-1, bsz, num_heads, head_dim)
-        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+        v = v.contiguous().view(-1, bsz, num_heads, head_dim)
+        v = self.whiten(v)  # does nothing in the forward pass.
+        v = v.view(-1, bsz * num_heads, head_dim).transpose(0, 1)

         if key_padding_mask is not None:
@@ -1278,14 +1287,15 @@ class RelPositionMultiheadAttention(nn.Module):
         Returns:
            output of the same shape as x, i.e. (seq_len, batch_size, embed_dim)
         """
-        attn_weights = self.entropy_penalty(attn_weights)
         num_heads = self.num_heads
         (seq_len, bsz, embed_dim) = x.shape
         head_dim = embed_dim // (num_heads * 2)
         # v: (tgt_len, bsz, embed_dim // 2)
         v = self.in_proj2(x)
+        v = v.contiguous().view(-1, bsz, num_heads, head_dim)
+        v = self.whiten2(v)  # does nothing in the forward pass.
         v = v.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
         # now v: (bsz * num_heads, seq_len, head_dim)
         attn_output = torch.bmm(attn_weights, v)
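
A hedged usage sketch of the new module (the import path is hypothetical; adapt it to wherever this conformer.py lives): the forward pass is an identity, and a penalty gradient is added in backward only when the whitening metric of the value features exceeds whitening_limit.

    import torch
    from conformer import Whiten  # hypothetical import path

    whiten = Whiten(whitening_limit=2.0, prob=1.0, grad_scale=0.0025)

    # (seq_len, bsz, num_heads, head_dim), matching the shape it is applied to above.
    v = torch.randn(100, 8, 4, 64, requires_grad=True)
    out = whiten(v)
    assert torch.equal(out, v)  # identity in the forward pass

    # Use the returned tensor downstream, or the penalty is silently dropped in backprop.
    (out ** 2).sum().backward()
    # v.grad is the ordinary gradient; a penalty term is added only when the
    # per-head covariance metric exceeds whitening_limit.
    print(v.grad.shape)
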