Merge branch 'scaled_adam_exp117' into scaled_adam_exp119

# Conflicts:
#	egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
Daniel Povey 2022-10-15 14:34:56 +08:00
commit 125e1b167c


@@ -801,6 +801,128 @@ class RelPositionalEncoding(torch.nn.Module):
        return self.dropout(pos_emb)
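

# EntropyPenaltyFunction is an identity in the forward pass; its entire effect
# is in backward(), where an extra gradient term pushes down the entropy of
# selected attention heads (see EntropyPenalty below).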
class EntropyPenaltyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx,
                attn_weights: Tensor,
                num_heads: int,
                entropy_limit: float,
                grad_scale: float) -> Tensor:
        ctx.save_for_backward(attn_weights)
        ctx.num_heads = num_heads
        ctx.entropy_limit = entropy_limit
        ctx.grad_scale = grad_scale
        return attn_weights
    @staticmethod
    def backward(ctx,
                 attn_weights_grad: Tensor):
        attn_weights, = ctx.saved_tensors
        num_heads = ctx.num_heads
        entropy_limit = ctx.entropy_limit
        grad_scale = ctx.grad_scale
        with torch.enable_grad():
            with torch.cuda.amp.autocast(enabled=False):
                attn_weights_orig = attn_weights.to(torch.float32).detach()
                attn_weights_orig.requires_grad = True
                bsz = attn_weights_orig.shape[0] // num_heads
                seq_len = attn_weights_orig.shape[2]
                attn_weights = attn_weights_orig.reshape(bsz, num_heads,
                                                         seq_len, seq_len)
                grad_norms = attn_weights_grad.detach().reshape(
                    bsz, num_heads, seq_len * seq_len).norm(dim=(0,2))
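                # grad_norms: (num_heads,), the norm of the incoming gradient for
                # each head, taken over the batch and all positions; heads whose
                # gradient norm is below the mean are treated as "not doing much"
                # and become candidates for the entropy penalty.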
                entropy = ((attn_weights + 1.0e-20).log() * attn_weights).sum(dim=-1)
                # entropy: (bsz, num_heads, seq_len)
                entropy = -entropy.mean(dim=(0,2))
                # entropy: (num_heads,)
                assert entropy.shape == (num_heads,)
                excess_entropy = (entropy - entropy_limit).relu()
                above_cutoff = (entropy > entropy_limit)  # tensor of shape (num_heads,)
                small_grad_norm = (grad_norms < grad_norms.mean())
                will_penalize = torch.logical_and(above_cutoff, small_grad_norm)
                if random.random() < 0.005 or __name__ == "__main__":
                    logging.info(f"entropy = {entropy}, entropy_limit={entropy_limit}, "
                                 f"above_cutoff={above_cutoff}, small_grad_norm={small_grad_norm}, "
                                 f"will_penalize={will_penalize}")
                will_penalize_sum = will_penalize.to(torch.float32).sum().item()
                if will_penalize_sum == 0:
                    # grad would be 0.  I'm guessing that checking this, and
                    # incurring a CUDA sync, may save time relative to doing the
                    # backprop of the entropy, but I'm not sure.
                    return attn_weights_grad, None, None, None
                # Treat `excess_entropy` as a loss, to be minimized.
                excess_entropy.backward(gradient=will_penalize.to(torch.float32))
                entropy_grad = attn_weights_orig.grad
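                # Rescale the entropy gradient so that its norm is about
                # `grad_scale` times the norm of the existing gradient, weighted
                # by the fraction of heads that are being penalized.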
                scale = ((grad_scale * will_penalize_sum / num_heads) *
                         (attn_weights_grad.to(torch.float32).norm() /
                          (entropy_grad.norm() + 1.0e-20)))
                entropy_grad = entropy_grad * scale
                return attn_weights_grad + entropy_grad.to(attn_weights_grad.dtype), None, None, None


class EntropyPenalty(nn.Module):
    def __init__(
        self,
        num_heads: int,
        entropy_delta: float,
        prob: float,
        grad_scale: float):
        """
        Args:
          num_heads: the number of attention heads in the self-attention module
             that this is attached to.
          entropy_delta: how far below the maximum entropy we place the entropy
             limit; if a head's entropy is above this limit we aim to decrease it.
             The limit is max(log(seq_len) - entropy_delta, 0.5 * log(seq_len));
             the second term makes sure the limit never becomes tiny or negative
             for short sequences.
          prob: the probability with which we apply this object.
          grad_scale: determines the scale on the gradient term from this object,
             relative to the rest of the gradient on the attention weights;
             will be divided by `prob`.
        """
        super(EntropyPenalty, self).__init__()
        self.num_heads = num_heads
        self.entropy_delta = entropy_delta
        self.prob = prob
        self.grad_scale = grad_scale

    def forward(self,
                attn_weights: Tensor) -> Tensor:
        """
        In the forward pass, this function just returns the attention weights.
        In the backward pass, it will modify the gradients to ensure that the
        entropy of the attention heads is not too large.  (We have noticed
        that too-large/almost-maximal entropy in the attention distribution
        is associated with heads that are not doing anything useful.)

        Args:
          attn_weights: the attention weights, after the softmax, with shape
            (batch_size * num_heads, seq_len, seq_len), satisfying:
            attn_weights.sum(dim=-1) == 1.
        Returns:
           the attn_weights, without any change.  You should make sure you use
           the returned attention weights, or the graph will be freed and
           nothing will happen in backprop.
        """
        if not attn_weights.requires_grad or random.random() > self.prob:
            return attn_weights
        else:
            seq_len = attn_weights.shape[2]
            max_entropy = math.log(seq_len)
            entropy_limit = max(max_entropy - self.entropy_delta,
                                0.5 * max_entropy)
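            # e.g. with seq_len = 100 and entropy_delta = 1.5:
            # max_entropy = log(100) ≈ 4.61, so
            # entropy_limit = max(4.61 - 1.5, 0.5 * 4.61) ≈ 3.11 nats.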
            return EntropyPenaltyFunction.apply(attn_weights,
                                                self.num_heads,
                                                entropy_limit,
                                                self.grad_scale / self.prob)


class RelPositionMultiheadAttention(nn.Module):
    r"""Multi-Head Attention layer with relative position encoding
@@ -848,6 +970,16 @@ class RelPositionMultiheadAttention(nn.Module):
        # linear transformation for positional encoding (projects to a scalar per head,
        # which will be added to the score).
        self.linear_pos = ScaledLinear(embed_dim, num_heads, initial_scale=0.05)
        self.entropy_penalty = EntropyPenalty(num_heads,
                                              entropy_delta=1.5,
                                              prob=1.0 if __name__ == "__main__" else 0.2,
                                              grad_scale=0.01)
        self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
        self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads))
        # linear transformation for positional encoding.
        self.linear_pos = nn.Linear(embed_dim, num_heads, bias=False)

    def forward(
        self,
@@ -1146,6 +1278,8 @@ class RelPositionMultiheadAttention(nn.Module):
        Returns:
           output of the same shape as x, i.e. (seq_len, batch_size, embed_dim)
        """
        attn_weights = self.entropy_penalty(attn_weights)
        num_heads = self.num_heads
        (seq_len, bsz, embed_dim) = x.shape
        head_dim = embed_dim // (num_heads * 2)
@@ -1154,6 +1288,10 @@ class RelPositionMultiheadAttention(nn.Module):
        v = v.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
        # now v: (bsz * num_heads, seq_len, head_dim)
        attn_output = torch.bmm(attn_weights, v)
        if random.random() < 0.001 or __name__ == "__main__":
            self._print_attn_stats(attn_weights, attn_output)
        # attn_output: (bsz * num_heads, seq_len, head_dim)
        attn_output = (
            attn_output.transpose(0, 1)
@@ -1164,6 +1302,38 @@ class RelPositionMultiheadAttention(nn.Module):
def _print_attn_stats(
self,
attn_weights: Tensor,
attn_output: Tensor):
# attn_weights: (batch_size * num_heads, seq_len, seq_len)
# attn_output: (bsz * num_heads, seq_len, head_dim)
(n, seq_len, head_dim) = attn_output.shape
num_heads = self.num_heads
bsz = n // num_heads
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False):
attn_weights = attn_weights.to(torch.float32)
attn_output = attn_output.to(torch.float32)
attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum(
dim=-1).reshape(bsz, num_heads, seq_len).mean(dim=(0,2))
attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim)
attn_output = attn_output.permute(1, 0, 2, 3).reshape(num_heads, bsz * seq_len, head_dim)
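                # attn_output now: (num_heads, bsz * seq_len, head_dim), i.e. all
                # output vectors of each head, so we can compute per-head statistics.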
                attn_output_mean = attn_output.mean(dim=1, keepdim=True)
                attn_output = attn_output - attn_output_mean
                attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / (bsz * seq_len)
                # attn_covar: (num_heads, head_dim, head_dim)
                #eigs, _ = torch.symeig(attn_covar)
                #logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}")
                attn_covar = attn_covar.mean(dim=1).sum(dim=1)  # (num_heads,)
                embed_dim = self.in_proj2.weight.shape[1]
                in_proj_covar = (self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2).mean(dim=(1,2))
                out_proj_covar = (self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2).mean(dim=(0,2))
                logging.info(f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, "
                             f"in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}")


class FeedforwardModule(nn.Module):
    """Feedforward module in Conformer model.
    """
@@ -1537,7 +1707,7 @@ def _test_conformer_main():
        torch.randn(batch_size, seq_len, feature_dim),
        torch.full((batch_size,), seq_len, dtype=torch.int64),
    )
    f  # to remove flake8 warnings
    f[0].sum().backward()
    c.eval()
    f = c(
        torch.randn(batch_size, seq_len, feature_dim),