Implement whitening of values in conformer.

Daniel Povey 2022-10-15 15:27:05 +08:00
parent 125e1b167c
commit 91840faa97


@@ -801,124 +801,126 @@ class RelPositionalEncoding(torch.nn.Module):
         return self.dropout(pos_emb)
 
 
-class EntropyPenaltyFunction(torch.autograd.Function):
+def _diag(x: Tensor):  # like .diag(), but works for tensors with 3 dims.
+    if x.ndim == 2:
+        return x.diag()
+    else:
+        (batch, dim, dim) = x.shape
+        x = x.reshape(batch, dim * dim)
+        x = x[:, ::dim+1]
+        assert x.shape == (batch, dim)
+        return x
+
+
+class WhiteningPenaltyFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx,
-                attn_weights: Tensor,
-                num_heads: int,
-                entropy_limit: float,
+                x: Tensor,
+                whitening_limit: float,
                 grad_scale: float) -> Tensor:
-        ctx.save_for_backward(attn_weights)
-        ctx.num_heads = num_heads
-        ctx.entropy_limit = entropy_limit
+        ctx.save_for_backward(x)
+        ctx.whitening_limit = whitening_limit
         ctx.grad_scale = grad_scale
-        return attn_weights
+        return x
 
     @staticmethod
     def backward(ctx,
-                 attn_weights_grad: Tensor):
-        attn_weights, = ctx.saved_tensors
-        num_heads = ctx.num_heads
-        entropy_limit = ctx.entropy_limit
-        grad_scale = ctx.grad_scale
+                 x_grad: Tensor):
+        x_orig, = ctx.saved_tensors
         with torch.enable_grad():
             with torch.cuda.amp.autocast(enabled=False):
-                attn_weights_orig = attn_weights.to(torch.float32).detach()
-                attn_weights_orig.requires_grad = True
-                bsz = attn_weights_orig.shape[0] // num_heads
-                seq_len = attn_weights_orig.shape[2]
-                attn_weights = attn_weights_orig.reshape(bsz, num_heads,
-                                                         seq_len, seq_len)
-                grad_norms = attn_weights_grad.detach().reshape(
-                    bsz, num_heads, seq_len * seq_len).norm(dim=(0,2))
+                x_detached = x_orig.to(torch.float32).detach()
+                x_detached.requires_grad = True
+                assert x_detached.ndim >= 3
+                x = x_detached.reshape(-1, x_detached.shape[-2],
+                                       x_detached.shape[-1]).transpose(0, 1)
+                (num_groups, num_frames, channels_per_group) = x.shape
+                # subtract the mean so we use the centered, not uncentered, covariance.
+                # My experience has been that when we "mess with the gradients" like this,
+                # it's better not to do anything that tries to move the mean around, because
+                # that can easily cause instability.
+                x = x - x.mean(dim=1, keepdim=True)
+                # x_covar: (num_groups, channels_per_group, channels_per_group)
+                x_covar = torch.matmul(x.transpose(1, 2), x)
+                # normalize x_covar so that its average diagonal element is 1.
+                x_covar = x_covar / (_diag(x_covar).mean() + 1.0e-20)
+                # x_covar_sq: (num_groups, channels_per_group, channels_per_group).
+                # if the normalized x_covar were just `num_groups` copies of the
+                # identity matrix, the mean diagonal element of x_covar_sq would
+                # also be 1.  But in general, it will be larger than that.
+                x_covar_sq = torch.matmul(x_covar, x_covar)
+                metric = _diag(x_covar_sq).mean()
 
-                entropy = ((attn_weights + 1.0e-20).log() * attn_weights).sum(dim=-1)
-                # entropy: (bsz, num_heads, seq_len)
-                entropy = -entropy.mean(dim=(0,2))
-                # entropy: (num_heads,)
-                assert entropy.shape == (num_heads,)
-                excess_entropy = (entropy - entropy_limit).relu()
-                above_cutoff = (entropy > 0)  # tensor of shape (num_heads,)
-                small_grad_norm = (grad_norms < grad_norms.mean())
-                will_penalize = torch.logical_and(above_cutoff, small_grad_norm)
                 if random.random() < 0.005 or __name__ == "__main__":
-                    logging.info(f"entropy = {entropy}, entropy_limit={entropy_limit}, above_cutoff={above_cutoff}, small_grad_norm={small_grad_norm}, will_penalize={will_penalize}")
-                will_penalize_sum = will_penalize.to(torch.float32).sum().item()
-                if will_penalize_sum == 0:
-                    # grad would be 0.  I'm guessing that checking this, and
-                    # incurring a CUDA sync, may save time relative to doing the
-                    # backprop of the entropy, but I'm not sure.
-                    return attn_weights_grad, None, None, None
-                # Treat `excess_entropy` as a loss, to be minimized.
-                excess_entropy.backward(gradient=will_penalize.to(torch.float32))
-                entropy_grad = attn_weights_orig.grad
-                scale = ((grad_scale * will_penalize_sum / num_heads) *
-                         (attn_weights_grad.to(torch.float32).norm() /
-                          (entropy_grad.norm() + 1.0e-20)))
-                entropy_grad = entropy_grad * scale
-                return attn_weights_grad + entropy_grad.to(attn_weights_grad.dtype), None, None, None
+                    logging.info(f"Whitening: num_groups={num_groups}, channels_per_group={channels_per_group}, "
+                                 f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}")
+
+                (metric - ctx.whitening_limit).relu().backward()
+                penalty_grad = x_detached.grad
+                scale = ctx.grad_scale * (x.to(torch.float32).norm() /
+                                          (penalty_grad.norm() + 1.0e-20))
+                penalty_grad = penalty_grad * scale
+        return x_grad + penalty_grad.to(x_grad.dtype), None, None
 
 
-class EntropyPenalty(nn.Module):
+class Whiten(nn.Module):
     def __init__(
             self,
-            num_heads: float,
-            entropy_delta: float,
+            whitening_limit: float,
             prob: float,
             grad_scale: float):
         """
         Args:
-          num_heads: the number of attention heads in the self-attention module that
-            this is attached to.
-          entropy_delta: the delta from the maximum entropy, that we aim to
-            decrease the entropy to if it is above.  So the maximum entropy
-            should be max(log(seq_len) - entropy_cutoff, 0.5 * log(seq_len));
-            the second term is to make sure the limit never becomes tiny or
-            negative in the case of short sequences.
-          prob: the probability with which we apply this object.
+          num_groups: the number of groups to divide the input into before
+            whitening it.  We will attempt to make the feature covariance
+            within each group, after mean subtraction, as "white" as possible
+            while having the same trace across all groups.
+          whitening_limit: a value greater than 1.0, that dictates how much
+            freedom we have to violate the constraints.  1.0 would mean perfectly
+            white, with exactly the same trace across groups; larger values
+            give more freedom.  E.g. 2.0.
+          prob: the probability with which we apply this object (also affects
+            grad scale).  e.g. 0.25
           grad_scale: determines the scale on the gradient term from this object,
             relative to the rest of the gradient on the attention weights;
-            will be divided by `prob`.
+            will be divided by `prob`.  e.g. 0.005
         """
-        super(EntropyPenalty, self).__init__()
-        self.num_heads = num_heads
-        self.entropy_delta = entropy_delta
+        super(Whiten, self).__init__()
+        assert whitening_limit >= 1
+        assert 0 < prob <= 1
+        assert grad_scale >= 0
+        self.whitening_limit = whitening_limit
         self.prob = prob
         self.grad_scale = grad_scale
 
     def forward(self,
-                attn_weights: Tensor) -> Tensor:
+                x: Tensor) -> Tensor:
         """
-        In the forward pass, this function just returns the attention weights.
+        In the forward pass, this function just returns the input unmodified.
         In the backward pass, it will modify the gradients to ensure that the
-        entropy of the attention heads is not too large.  (We have noticed
-        that too-large/almost-maximal entropy in the attention distribution
-        is associated with heads that are not doing anything useful.
+        distribution in each group has close to (lambda times I) as the covariance
+        after mean subtraction, with the same lambda across groups.
+        For whitening_limit > 1, there will be more freedom to violate this
+        constraint.
 
         Args:
-          attn_weights: the attention weights, after the log, with shape
-            (batch_size * num_heads, seq_len, seq_len), satisfying:
-            attn_weights.sum(dim=-1) == 1.
+          x: the input of shape (*, num_groups, channels_per_group)
         Returns:
-          the attn_weights, without any change.  You should make sure
-          you use the returned attention weights, or the graph will be freed
+          x, unmodified.  You should make sure
+          you use the returned value, or the graph will be freed
           and nothing will happen in backprop.
         """
-        if not attn_weights.requires_grad or random.random() > self.prob:
-            return attn_weights
+        if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
+            return x
         else:
-            seq_len = attn_weights.shape[2]
-            max_entropy = math.log(seq_len)
-            entropy_limit = max(max_entropy - self.entropy_delta,
-                                0.5 * max_entropy)
-            return EntropyPenaltyFunction.apply(attn_weights,
-                                                self.num_heads,
-                                                entropy_limit,
-                                                self.grad_scale / self.prob)
+            return WhiteningPenaltyFunction.apply(x,
+                                                  self.whitening_limit,
+                                                  self.grad_scale / self.prob)
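
For intuition about the penalty above: the metric computed in WhiteningPenaltyFunction.backward() equals 1.0 exactly when every group's centered covariance is the same multiple of the identity, and grows as the features become less "white". The following standalone sketch (not part of the commit; the shapes and the helper name whitening_metric are made up for illustration) restates the metric using Tensor.diagonal() and evaluates it on white versus unevenly-scaled data.

import torch

def whitening_metric(x: torch.Tensor) -> torch.Tensor:
    # x: (num_groups, num_frames, channels_per_group), as in the backward pass above.
    x = x - x.mean(dim=1, keepdim=True)  # centered covariance, as in the diff
    x_covar = torch.matmul(x.transpose(1, 2), x)
    x_covar = x_covar / (x_covar.diagonal(dim1=-2, dim2=-1).mean() + 1.0e-20)
    return torch.matmul(x_covar, x_covar).diagonal(dim1=-2, dim2=-1).mean()

white = torch.randn(8, 2000, 32)               # roughly white features
scaled = white * torch.linspace(0.2, 3.0, 32)  # per-channel scales differ
print(whitening_metric(white))    # close to 1.0
print(whitening_metric(scaled))   # clearly above 1.0: this is what gets penalized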
@@ -955,6 +957,13 @@ class RelPositionMultiheadAttention(nn.Module):
         ), "embed_dim//2 must be divisible by num_heads"
         self.in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True)
+        # self.whiten is applied on the values in forward()
+        self.whiten = Whiten(whitening_limit=2.0,
+                             prob=1.0 if __name__ == "__main__" else 0.1,
+                             grad_scale=0.0025)
         self.in_balancer = ActivationBalancer(3 * embed_dim // 2,
                                               channel_dim=-1, max_abs=5.0)
         self.in_max_eig = MaxEig(3 * embed_dim // 2,
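
A small usage sketch of the Whiten module wired in above (not from the commit; the shapes are made up, prob=1.0 is used only so the penalty is always applied, and the classes from this file are assumed to be in scope):

import torch

whiten = Whiten(whitening_limit=2.0, prob=1.0, grad_scale=0.0025)

# (seq_len, bsz, num_heads, head_dim): num_heads acts as num_groups and
# head_dim as channels_per_group, matching how forward() uses self.whiten.
seq_len, bsz, num_heads, head_dim = 100, 4, 8, 32
base = torch.randn(seq_len, bsz, num_heads, 1)
v = (base + 0.1 * torch.randn(seq_len, bsz, num_heads, head_dim)).requires_grad_()

out = whiten(v)          # identical to v in the forward pass
out.sum().backward()
# The values above are strongly correlated across head_dim, so the metric
# exceeds whitening_limit and a small penalty term is added to the gradient.
print((v.grad - 1.0).abs().max())  # nonzero: the grad of sum() alone would be all ones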
@@ -966,17 +975,15 @@ class RelPositionMultiheadAttention(nn.Module):
         self.in_proj2 = nn.Linear(embed_dim, embed_dim // 2, bias=False)
         self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True,
                                       initial_scale=0.05)
+        # self.whiten2 is applied on the values in forward2()
+        self.whiten2 = Whiten(whitening_limit=2.0,
+                              prob=1.0 if __name__ == "__main__" else 0.1,
+                              grad_scale=0.0025)
         # linear transformation for positional encoding (projects to a scalar per head,
         # which will be added to the score).
         self.linear_pos = ScaledLinear(embed_dim, num_heads, initial_scale=0.05)
-        self.entropy_penalty = EntropyPenalty(num_heads,
-                                              entropy_delta=1.5,
-                                              prob=1.0 if __name__ == "__main__" else 0.2,
-                                              grad_scale=0.01)
-        self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
-        self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads))
 
         # linear transformation for positional encoding.
         self.linear_pos = nn.Linear(embed_dim, num_heads, bias=False)
@@ -1196,7 +1203,9 @@ class RelPositionMultiheadAttention(nn.Module):
         q = (q * scaling).contiguous().view(seq_len, bsz, num_heads, head_dim)
         k = k.contiguous().view(-1, bsz, num_heads, head_dim)
-        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+        v = v.contiguous().view(-1, bsz, num_heads, head_dim)
+        v = self.whiten(v)  # does nothing in the forward pass.
+        v = v.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
 
         if key_padding_mask is not None:
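
The reshape added around self.whiten(v) does not change the values that reach the attention computation; it only exposes a (num_heads, head_dim) grouping to the whitening penalty. A quick check (illustrative sizes, not from the commit):

import torch

seq_len, bsz, num_heads, head_dim = 10, 2, 4, 8
v = torch.randn(seq_len, bsz, num_heads * head_dim)

v4 = v.contiguous().view(-1, bsz, num_heads, head_dim)                     # what self.whiten sees
new = v4.view(-1, bsz * num_heads, head_dim).transpose(0, 1)               # path added by this commit
old = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)   # previous path
assert torch.equal(new, old)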
@@ -1278,14 +1287,15 @@ class RelPositionMultiheadAttention(nn.Module):
         Returns:
            output of the same shape as x, i.e. (seq_len, batch_size, embed_dim)
         """
-        attn_weights = self.entropy_penalty(attn_weights)
         num_heads = self.num_heads
         (seq_len, bsz, embed_dim) = x.shape
         head_dim = embed_dim // (num_heads * 2)
         # v: (tgt_len, bsz, embed_dim // 2)
         v = self.in_proj2(x)
+        v = v.contiguous().view(-1, bsz, num_heads, head_dim)
+        v = self.whiten2(v)  # does nothing in the forward pass.
         v = v.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
         # now v: (bsz * num_heads, seq_len, head_dim)
         attn_output = torch.bmm(attn_weights, v)
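
One property of the backward pass worth noting: the penalty added to the incoming gradient is rescaled so that its norm is grad_scale times the norm of the reshaped, mean-subtracted input, regardless of how large the raw penalty gradient was. A toy check of that, assuming the classes in this diff are in scope (shapes are made up):

import torch

x = (torch.randn(50, 2, 8, 1) + 0.1 * torch.randn(50, 2, 8, 16)).requires_grad_()
y = WhiteningPenaltyFunction.apply(x, 2.0, 0.01)   # whitening_limit=2.0, grad_scale=0.01
y.sum().backward()
penalty = x.grad - 1.0                             # incoming grad of sum() is all ones

# Reproduce the centered (num_groups, num_frames, channels) view used in backward().
xc = x.detach().reshape(-1, 8, 16).transpose(0, 1)
xc = xc - xc.mean(dim=1, keepdim=True)
print(penalty.norm().item(), 0.01 * xc.norm().item())  # approximately equal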