Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Merge branch 'scaled_adam_exp117' into scaled_adam_exp119

# Conflicts:
#	egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py

commit 125e1b167c
@@ -801,6 +801,128 @@ class RelPositionalEncoding(torch.nn.Module):

        return self.dropout(pos_emb)

class EntropyPenaltyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx,
                attn_weights: Tensor,
                num_heads: int,
                entropy_limit: float,
                grad_scale: float) -> Tensor:
        ctx.save_for_backward(attn_weights)
        ctx.num_heads = num_heads
        ctx.entropy_limit = entropy_limit
        ctx.grad_scale = grad_scale
        return attn_weights

    @staticmethod
    def backward(ctx,
                 attn_weights_grad: Tensor):
        attn_weights, = ctx.saved_tensors
        num_heads = ctx.num_heads
        entropy_limit = ctx.entropy_limit
        grad_scale = ctx.grad_scale
        with torch.enable_grad():
            with torch.cuda.amp.autocast(enabled=False):
                attn_weights_orig = attn_weights.to(torch.float32).detach()
                attn_weights_orig.requires_grad = True
                bsz = attn_weights_orig.shape[0] // num_heads
                seq_len = attn_weights_orig.shape[2]
                attn_weights = attn_weights_orig.reshape(bsz, num_heads,
                                                         seq_len, seq_len)

                grad_norms = attn_weights_grad.detach().reshape(
                    bsz, num_heads, seq_len * seq_len).norm(dim=(0, 2))

                entropy = ((attn_weights + 1.0e-20).log() * attn_weights).sum(dim=-1)
                # entropy: (bsz, num_heads, seq_len)
                entropy = -entropy.mean(dim=(0, 2))
                # entropy: (num_heads,)
                assert entropy.shape == (num_heads,)
                excess_entropy = (entropy - entropy_limit).relu()
                above_cutoff = (entropy > 0)  # tensor of shape (num_heads,)
                small_grad_norm = (grad_norms < grad_norms.mean())
                will_penalize = torch.logical_and(above_cutoff, small_grad_norm)
                if random.random() < 0.005 or __name__ == "__main__":
                    logging.info(f"entropy = {entropy}, entropy_limit={entropy_limit}, above_cutoff={above_cutoff}, small_grad_norm={small_grad_norm}, will_penalize={will_penalize}")
                will_penalize_sum = will_penalize.to(torch.float32).sum().item()
                if will_penalize_sum == 0:
                    # grad would be 0.  I'm guessing that checking this, and
                    # incurring a CUDA sync, may save time relative to doing the
                    # backprop of the entropy, but I'm not sure.
                    return attn_weights_grad, None, None, None
                # Treat `excess_entropy` as a loss, to be minimized.
                excess_entropy.backward(gradient=will_penalize.to(torch.float32))
                entropy_grad = attn_weights_orig.grad
                scale = ((grad_scale * will_penalize_sum / num_heads) *
                         (attn_weights_grad.to(torch.float32).norm() /
                          (entropy_grad.norm() + 1.0e-20)))
                entropy_grad = entropy_grad * scale
                return attn_weights_grad + entropy_grad.to(attn_weights_grad.dtype), None, None, None

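For intuition about the quantity being penalized above: `entropy` is the Shannon entropy of each attention row, averaged over the batch and query positions to give one value per head, and a uniform distribution over seq_len keys attains the maximum value log(seq_len). A tiny standalone check of that fact (toy shapes chosen only for illustration, not part of the diff):

import math
import torch

bsz, num_heads, seq_len = 2, 4, 10

# Uniform attention: every row sums to 1 and has entropy log(seq_len).
attn_weights = torch.full((bsz, num_heads, seq_len, seq_len), 1.0 / seq_len)

# Same formula as in EntropyPenaltyFunction.backward():
entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum(dim=-1).mean(dim=(0, 2))

print(entropy)            # ~2.3026 for every head
print(math.log(seq_len))  # 2.3026, the maximum attainable entropy
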
class EntropyPenalty(nn.Module):
    def __init__(
            self,
            num_heads: int,
            entropy_delta: float,
            prob: float,
            grad_scale: float):
        """
        Args:
          num_heads: the number of attention heads in the self-attention module
             that this is attached to.
          entropy_delta: the delta from the maximum entropy that we aim to
             decrease the entropy to if it is above.  So the entropy limit will
             be max(log(seq_len) - entropy_delta, 0.5 * log(seq_len)); the
             second term is to make sure the limit never becomes tiny or
             negative in the case of short sequences.
          prob: the probability with which we apply this object.
          grad_scale: determines the scale on the gradient term from this
             object, relative to the rest of the gradient on the attention
             weights; will be divided by `prob`.
        """
        super(EntropyPenalty, self).__init__()
        self.num_heads = num_heads
        self.entropy_delta = entropy_delta
        self.prob = prob
        self.grad_scale = grad_scale

    def forward(self,
                attn_weights: Tensor) -> Tensor:
        """
        In the forward pass, this function just returns the attention weights.
        In the backward pass, it will modify the gradients to ensure that the
        entropy of the attention heads is not too large.  (We have noticed
        that too-large/almost-maximal entropy in the attention distribution
        is associated with heads that are not doing anything useful.)

        Args:
           attn_weights: the attention weights, after the softmax, with shape
              (batch_size * num_heads, seq_len, seq_len), satisfying
              attn_weights.sum(dim=-1) == 1.
        Returns:
           the attn_weights, without any change.  You should make sure you use
           the returned attention weights, or the graph will be freed and
           nothing will happen in backprop.
        """
        if not attn_weights.requires_grad or random.random() > self.prob:
            return attn_weights
        else:
            seq_len = attn_weights.shape[2]
            max_entropy = math.log(seq_len)
            entropy_limit = max(max_entropy - self.entropy_delta,
                                0.5 * max_entropy)
            return EntropyPenaltyFunction.apply(attn_weights,
                                                self.num_heads,
                                                entropy_limit,
                                                self.grad_scale / self.prob)

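A minimal usage sketch (toy shapes and a made-up loss, not taken from the commit; it assumes the two classes above are in scope along with the imports used by conformer.py): the module is a no-op in the forward pass, so the only visible requirement is to keep using the returned tensor, as the docstring warns.

import torch

bsz, num_heads, seq_len = 2, 4, 50

penalty = EntropyPenalty(num_heads=num_heads, entropy_delta=1.5,
                         prob=1.0, grad_scale=0.01)

scores = torch.randn(bsz * num_heads, seq_len, seq_len, requires_grad=True)
attn_weights = scores.softmax(dim=-1)

out = penalty(attn_weights)            # identity in the forward pass
assert torch.equal(out, attn_weights)

# Backprop through the *returned* tensor so the custom backward runs and
# can add the (scaled) entropy gradient to the attention-weight gradient.
out.pow(2).sum().backward()
print(scores.grad.shape)               # torch.Size([8, 50, 50])
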
class RelPositionMultiheadAttention(nn.Module):
    r"""Multi-Head Attention layer with relative position encoding

@@ -848,6 +970,16 @@ class RelPositionMultiheadAttention(nn.Module):

        # linear transformation for positional encoding (projects to a scalar per head,
        # which will be added to the score).
        self.linear_pos = ScaledLinear(embed_dim, num_heads, initial_scale=0.05)
        self.entropy_penalty = EntropyPenalty(num_heads,
                                              entropy_delta=1.5,
                                              prob=1.0 if __name__ == "__main__" else 0.2,
                                              grad_scale=0.01)

        self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
        self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads))

        # linear transformation for positional encoding.
        self.linear_pos = nn.Linear(embed_dim, num_heads, bias=False)

    def forward(
        self,
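To put concrete numbers on the entropy_delta=1.5 setting wired in here (a worked example with an assumed sequence length, not taken from the commit): EntropyPenalty.forward() turns the delta into a per-length limit, and the penalty is active on roughly 20% of training forward passes since prob=0.2, with grad_scale divided by prob to compensate.

import math

seq_len = 100          # assumed value, for illustration only
entropy_delta = 1.5    # value passed to EntropyPenalty above

max_entropy = math.log(seq_len)                # ~4.605 nats
entropy_limit = max(max_entropy - entropy_delta,
                    0.5 * max_entropy)         # ~3.105 nats
print(max_entropy, entropy_limit)
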
@@ -1146,6 +1278,8 @@ class RelPositionMultiheadAttention(nn.Module):

        Returns:
           output of the same shape as x, i.e. (seq_len, batch_size, embed_dim)
        """
        attn_weights = self.entropy_penalty(attn_weights)

        num_heads = self.num_heads
        (seq_len, bsz, embed_dim) = x.shape
        head_dim = embed_dim // (num_heads * 2)
@@ -1154,6 +1288,10 @@ class RelPositionMultiheadAttention(nn.Module):
        v = v.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
        # now v: (bsz * num_heads, seq_len, head_dim)
        attn_output = torch.bmm(attn_weights, v)

        if random.random() < 0.001 or __name__ == "__main__":
            self._print_attn_stats(attn_weights, attn_output)

        # attn_output: (bsz * num_heads, seq_len, head_dim)
        attn_output = (
            attn_output.transpose(0, 1)
@@ -1164,6 +1302,38 @@ class RelPositionMultiheadAttention(nn.Module):
        return self.out_proj2(attn_output)

    def _print_attn_stats(
            self,
            attn_weights: Tensor,
            attn_output: Tensor):
        # attn_weights: (batch_size * num_heads, seq_len, seq_len)
        # attn_output: (bsz * num_heads, seq_len, head_dim)
        (n, seq_len, head_dim) = attn_output.shape
        num_heads = self.num_heads
        bsz = n // num_heads

        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=False):
                attn_weights = attn_weights.to(torch.float32)
                attn_output = attn_output.to(torch.float32)
                attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum(
                    dim=-1).reshape(bsz, num_heads, seq_len).mean(dim=(0, 2))
                attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim)
                attn_output = attn_output.permute(1, 0, 2, 3).reshape(num_heads, bsz * seq_len, head_dim)
                attn_output_mean = attn_output.mean(dim=1, keepdim=True)
                attn_output = attn_output - attn_output_mean
                attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / (bsz * seq_len)
                # attn_covar: (num_heads, head_dim, head_dim)
                # eigs, _ = torch.symeig(attn_covar)
                # logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}")

                attn_covar = attn_covar.mean(dim=1).sum(dim=1)  # (num_heads,)
                embed_dim = self.in_proj2.weight.shape[1]
                in_proj_covar = (self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2).mean(dim=(1, 2))
                out_proj_covar = (self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2).mean(dim=(0, 2))
                logging.info(f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}")

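As a quick sanity check on the reshapes in _print_attn_stats (toy sizes, independent of the module): folding the batch into the time axis per head and multiplying the centred output by its transpose yields one head_dim x head_dim covariance per head, which is then reduced to one scalar per head for logging.

import torch

bsz, num_heads, seq_len, head_dim = 2, 4, 10, 8   # assumed toy sizes

attn_output = torch.randn(bsz * num_heads, seq_len, head_dim)
x = attn_output.reshape(bsz, num_heads, seq_len, head_dim)
x = x.permute(1, 0, 2, 3).reshape(num_heads, bsz * seq_len, head_dim)
x = x - x.mean(dim=1, keepdim=True)

covar = torch.matmul(x.transpose(1, 2), x) / (bsz * seq_len)
print(covar.shape)                          # (num_heads, head_dim, head_dim)
print(covar.mean(dim=1).sum(dim=1).shape)   # (num_heads,), as logged above
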
class FeedforwardModule(nn.Module):
    """Feedforward module in Conformer model.
    """
@@ -1537,7 +1707,7 @@ def _test_conformer_main():
        torch.randn(batch_size, seq_len, feature_dim),
        torch.full((batch_size,), seq_len, dtype=torch.int64),
    )
    f  # to remove flake8 warnings
    f[0].sum().backward()
    c.eval()
    f = c(
        torch.randn(batch_size, seq_len, feature_dim),