Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)

Implement whitening of values in conformer.

commit 91840faa97 (parent 125e1b167c)
@@ -801,124 +801,126 @@ class RelPositionalEncoding(torch.nn.Module):
         return self.dropout(pos_emb)


-class EntropyPenaltyFunction(torch.autograd.Function):
+def _diag(x: Tensor):  # like .diag(), but works for tensors with 3 dims.
+    if x.ndim == 2:
+        return x.diag()
+    else:
+        (batch, dim, dim) = x.shape
+        x = x.reshape(batch, dim * dim)
+        x = x[:, ::dim+1]
+        assert x.shape == (batch, dim)
+        return x
+
+
+class WhiteningPenaltyFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx,
-                attn_weights: Tensor,
-                num_heads: int,
-                entropy_limit: float,
+                x: Tensor,
+                whitening_limit: float,
                 grad_scale: float) -> Tensor:
-        ctx.save_for_backward(attn_weights)
-        ctx.num_heads = num_heads
-        ctx.entropy_limit = entropy_limit
+        ctx.save_for_backward(x)
+        ctx.whitening_limit = whitening_limit
         ctx.grad_scale = grad_scale
-        return attn_weights
+        return x

     @staticmethod
     def backward(ctx,
-                 attn_weights_grad: Tensor):
-        attn_weights, = ctx.saved_tensors
-        num_heads = ctx.num_heads
-        entropy_limit = ctx.entropy_limit
-        grad_scale = ctx.grad_scale
+                 x_grad: Tensor):
+        x_orig, = ctx.saved_tensors
         with torch.enable_grad():
             with torch.cuda.amp.autocast(enabled=False):
-                attn_weights_orig = attn_weights.to(torch.float32).detach()
-                attn_weights_orig.requires_grad = True
-                bsz = attn_weights_orig.shape[0] // num_heads
-                seq_len = attn_weights_orig.shape[2]
-                attn_weights = attn_weights_orig.reshape(bsz, num_heads,
-                                                         seq_len, seq_len)
+                x_detached = x_orig.to(torch.float32).detach()
+                x_detached.requires_grad = True
+                assert x_detached.ndim >= 3
+                x = x_detached.reshape(-1, x_detached.shape[-2],
+                                       x_detached.shape[-1]).transpose(0, 1)
+                (num_groups, num_frames, channels_per_group) = x.shape

-                grad_norms = attn_weights_grad.detach().reshape(
-                    bsz, num_heads, seq_len * seq_len).norm(dim=(0,2))
+                # subtract the mean so we use the centered, not uncentered, covariance.
+                # My experience has been that when we "mess with the gradients" like this,
+                # it's better not to do anything that tries to move the mean around, because
+                # that can easily cause instability.
+                x = x - x.mean(dim=1, keepdim=True)
+
+                # x_covar: (num_groups, channels_per_group, channels_per_group)
+                x_covar = torch.matmul(x.transpose(1, 2), x)
+                # normalize x_covar so that its average diagonal element is 1.
+                x_covar = x_covar / (_diag(x_covar).mean() + 1.0e-20)
+                # x_covar_sq: (num_groups, channels_per_group, channels_per_group).
+                # if the normalized x_covar were just `num_groups` copies of the
+                # identity matrix, x_covar_sq would have the same value.  But
+                # in general, it will be larger than that.
+                x_covar_sq = torch.matmul(x_covar, x_covar)
+
+                metric = _diag(x_covar_sq).mean()

-                entropy = ((attn_weights + 1.0e-20).log() * attn_weights).sum(dim=-1)
-                # entropy: (bsz, num_heads, seq_len)
-                entropy = -entropy.mean(dim=(0,2))
-                # entropy: (num_heads,)
-                assert entropy.shape == (num_heads,)
-                excess_entropy = (entropy - entropy_limit).relu()
-                above_cutoff = (entropy > 0)  # tensor of shape (num_heads,)
-                small_grad_norm = (grad_norms < grad_norms.mean())
-                will_penalize = torch.logical_and(above_cutoff, small_grad_norm)
-                if random.random() < 0.005 or __name__ == "__main__":
-                    logging.info(f"entropy = {entropy}, entropy_limit={entropy_limit}, above_cutoff={above_cutoff}, small_grad_norm={small_grad_norm}, will_penalize={will_penalize}")
-                will_penalize_sum = will_penalize.to(torch.float32).sum().item()
-                if will_penalize_sum == 0:
-                    # grad would be 0.  I'm guessing that checking this, and
-                    # incurring a CUDA sync, may save time relative to doing the
-                    # backprop of the entropy, but I'm not sure.
-                    return attn_weights_grad, None, None, None
-                # Treat `excess_entropy` as a loss, to be minimized.
-                excess_entropy.backward(gradient=will_penalize.to(torch.float32))
-                entropy_grad = attn_weights_orig.grad
-                scale = ((grad_scale * will_penalize_sum / num_heads) *
-                         (attn_weights_grad.to(torch.float32).norm() /
-                          (entropy_grad.norm() + 1.0e-20)))
-                entropy_grad = entropy_grad * scale
-                return attn_weights_grad + entropy_grad.to(attn_weights_grad.dtype), None, None, None
+                logging.info(f"Whitening: num_groups={num_groups}, channels_per_group={channels_per_group}, "
+                             f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}")
+
+                (metric - ctx.whitening_limit).relu().backward()
+                penalty_grad = x_detached.grad
+                scale = ctx.grad_scale * (x.to(torch.float32).norm() /
+                                          (penalty_grad.norm() + 1.0e-20))
+                penalty_grad = penalty_grad * scale
+                return x_grad + penalty_grad.to(x_grad.dtype), None, None


-class EntropyPenalty(nn.Module):
+class Whiten(nn.Module):
     def __init__(
             self,
-            num_heads: float,
-            entropy_delta: float,
+            whitening_limit: float,
             prob: float,
             grad_scale: float):
         """
         Args:
-          num_heads: the number of attention heads in the self-attention module that
-            this is attached to.
-          entropy_delta: the delta from the maximum entropy, that we aim to
-            decrease the entropy to if it is above.  So the maximum entropy
-            should be max(log(seq_len) - entropy_cutoff, 0.5 * log(seq_len));
-            the second term is to make sure the limit never becomes tiny or
-            negative in the case of short sequences.
-          prob: the probability with which we apply this object.
+          num_groups: the number of groups to divide the input into before
+            whitening it.  We will attempt to make the feature covariance
+            within each group, after mean subtraction, as "white" as possible
+            while having the same trace across all groups.
+          whitening_limit: a value greater than 1.0, that dictates how much
+            freedom we have to violate the constraints.  1.0 would mean perfectly
+            white, with exactly the same trace across groups; larger values
+            give more freedom.  E.g. 2.0.
+          prob: the probability with which we apply this object (also affects
+            grad scale).  e.g. 0.25
           grad_scale: determines the scale on the gradient term from this object,
            relative to the rest of the gradient on the attention weights;
-            will be divided by `prob`.
+            will be divided by `prob`.  e.g. 0.005
         """
-        super(EntropyPenalty, self).__init__()
-        self.num_heads = num_heads
-        self.entropy_delta = entropy_delta
+        super(Whiten, self).__init__()
+        assert whitening_limit >= 1
+        assert 0 < prob <= 1
+        assert grad_scale >= 0
+        self.whitening_limit = whitening_limit
         self.prob = prob
         self.grad_scale = grad_scale

     def forward(self,
-                attn_weights: Tensor) -> Tensor:
+                x: Tensor) -> Tensor:
         """
-        In the forward pass, this function just returns the attention weights.
+        In the forward pass, this function just returns the input unmodified.
         In the backward pass, it will modify the gradients to ensure that the
-        entropy of the attention heads is not too large.  (We have noticed
-        that too-large/almost-maximal entropy in the attention distribution
-        is associated with heads that are not doing anything useful.
+        distribution in each group has close to (lambda times I) as the covariance
+        after mean subtraction, with the same lambda across groups.
+        For whitening_limit > 1, there will be more freedom to violate this
+        constraint.

         Args:
-          attn_weights: the attention weights, after the log, with shape
-              (batch_size * num_heads, seq_len, seq_len), satisfying:
-                 attn_weights.sum(dim=-1) == 1.
+           x: the input of shape (*, num_groups, channels_per_group)

         Returns:
-            the attn_weights, without any change.  You should make sure
-            you use the returned attention weights, or the graph will be freed
-            and nothing will happen in backprop.
+           x, unmodified.  You should make sure
+           you use the returned value, or the graph will be freed
+           and nothing will happen in backprop.
         """
-        if not attn_weights.requires_grad or random.random() > self.prob:
-            return attn_weights
+        if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
+            return x
         else:
-            seq_len = attn_weights.shape[2]
-            max_entropy = math.log(seq_len)
-            entropy_limit = max(max_entropy - self.entropy_delta,
-                                0.5 * max_entropy)
-            return EntropyPenaltyFunction.apply(attn_weights,
-                                                self.num_heads,
-                                                entropy_limit,
-                                                self.grad_scale / self.prob)
+            return WhiteningPenaltyFunction.apply(x,
+                                                  self.whitening_limit,
+                                                  self.grad_scale / self.prob)

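An aside on the `metric` computed in WhiteningPenaltyFunction.backward above (not part of the commit): after the covariance is normalized so that its average diagonal element is 1, the metric is the mean squared eigenvalue of the per-group covariances, whose eigenvalues average to 1. It is therefore at least 1, equals 1 only when every group's covariance is the identity times a common scale, and reaches about channels_per_group when the channels in each group are perfectly correlated. A minimal standalone sketch, assuming only stock PyTorch; the helper name whitening_metric and the toy shapes are invented for illustration:

import torch

def whitening_metric(x: torch.Tensor) -> torch.Tensor:
    # x: (num_groups, num_frames, channels_per_group), the shape used in
    # WhiteningPenaltyFunction.backward after the reshape/transpose.
    x = x - x.mean(dim=1, keepdim=True)            # centered covariance
    x_covar = torch.matmul(x.transpose(1, 2), x)   # (groups, chans, chans)
    x_covar = x_covar / (torch.diagonal(x_covar, dim1=-2, dim2=-1).mean() + 1.0e-20)
    x_covar_sq = torch.matmul(x_covar, x_covar)
    return torch.diagonal(x_covar_sq, dim1=-2, dim2=-1).mean()

white = torch.randn(4, 1000, 32)                              # independent channels
rank1 = torch.randn(4, 1000, 1).expand(4, 1000, 32).contiguous()  # identical channels
print(whitening_metric(white).item())   # close to 1.0 (slightly above, from sampling noise)
print(whitening_metric(rank1).item())   # about 32.0, i.e. channels_per_group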
@@ -955,6 +957,13 @@ class RelPositionMultiheadAttention(nn.Module):
         ), "embed_dim//2 must be divisible by num_heads"

         self.in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True)
+
+        # self.whiten is applied on the values in forward()
+        self.whiten = Whiten(whitening_limit=2.0,
+                             prob=1.0 if __name__ == "__main__" else 0.1,
+                             grad_scale=0.0025)
+
+
         self.in_balancer = ActivationBalancer(3 * embed_dim // 2,
                                               channel_dim=-1, max_abs=5.0)
         self.in_max_eig = MaxEig(3 * embed_dim // 2,
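A note on the constants above (an inference from the Whiten docstring in the first hunk, not an extra statement in the commit): the module fires with probability prob and passes grad_scale / prob to the penalty function, so with prob=0.1 and grad_scale=0.0025 the penalty is applied on roughly one call in ten at an effective scale of 0.0025 / 0.1 = 0.025, keeping its expected strength per call at about 0.0025 while skipping the float32 covariance computation on most steps; prob=1.0 under __name__ == "__main__" presumably just makes the penalty fire on every call when the file is run directly.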
@@ -966,17 +975,15 @@ class RelPositionMultiheadAttention(nn.Module):
         self.in_proj2 = nn.Linear(embed_dim, embed_dim // 2, bias=False)
         self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True,
                                       initial_scale=0.05)
+        # self.whiten2 is applied on the values in forward2()
+        self.whiten2 = Whiten(whitening_limit=2.0,
+                              prob=1.0 if __name__ == "__main__" else 0.1,
+                              grad_scale=0.0025)
+

-        # linear transformation for positional encoding (projects to a scalar per head,
-        # which will be added to the score).
-        self.linear_pos = ScaledLinear(embed_dim, num_heads, initial_scale=0.05)
-        self.entropy_penalty = EntropyPenalty(num_heads,
-                                              entropy_delta=1.5,
-                                              prob=1.0 if __name__ == "__main__" else 0.2,
-                                              grad_scale=0.01)

         self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
         self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads))

         # linear transformation for positional encoding.
         self.linear_pos = nn.Linear(embed_dim, num_heads, bias=False)
@@ -1196,7 +1203,9 @@ class RelPositionMultiheadAttention(nn.Module):

         q = (q * scaling).contiguous().view(seq_len, bsz, num_heads, head_dim)
         k = k.contiguous().view(-1, bsz, num_heads, head_dim)
-        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+        v = v.contiguous().view(-1, bsz, num_heads, head_dim)
+        v = self.whiten(v)  # does nothing in the forward pass.
+        v = v.view(-1, bsz * num_heads, head_dim).transpose(0, 1)


         if key_padding_mask is not None:
@@ -1278,14 +1287,15 @@ class RelPositionMultiheadAttention(nn.Module):
         Returns:
            output of the same shape as x, i.e. (seq_len, batch_size, embed_dim)
         """
-        attn_weights = self.entropy_penalty(attn_weights)

         num_heads = self.num_heads
         (seq_len, bsz, embed_dim) = x.shape
         head_dim = embed_dim // (num_heads * 2)
         # v: (tgt_len, bsz, embed_dim // 2)
         v = self.in_proj2(x)
+        v = v.contiguous().view(-1, bsz, num_heads, head_dim)
+        v = self.whiten2(v)  # does nothing in the forward pass.
         v = v.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)

         # now v: (bsz * num_heads, seq_len, head_dim)
         attn_output = torch.bmm(attn_weights, v)
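Finally, a minimal usage sketch (not part of the commit) of the contract described in Whiten.forward's docstring. It assumes the Whiten class from the first hunk is in scope (e.g. when running this conformer.py directly); the shapes are invented for illustration, with num_heads playing the role of num_groups and head_dim that of channels_per_group, as in forward() and forward2() above:

import torch

whiten = Whiten(whitening_limit=2.0, prob=1.0, grad_scale=0.0025)

x = torch.randn(100, 2, 4, 32, requires_grad=True)  # (seq_len, batch, num_heads, head_dim)
y = whiten(x)              # identity in the forward pass
assert torch.equal(y, x)

# Build the loss on the returned tensor; using x instead of y would bypass
# WhiteningPenaltyFunction.backward entirely, as the docstring warns.
loss = (y ** 2).sum()
loss.backward()            # logs the whitening metric; if the metric exceeds
                           # whitening_limit, a small extra term is folded into x.grad
print(x.grad.shape)        # torch.Size([100, 2, 4, 32])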