diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index a4d911e1f..90449617d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -805,6 +805,127 @@ class RelPositionalEncoding(torch.nn.Module): return self.dropout(pos_emb) +class EntropyPenaltyFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, + attn_weights: Tensor, + num_heads: int, + entropy_limit: float, + grad_scale: float) -> Tensor: + logging.info("Here3") + ctx.save_for_backward(attn_weights) + ctx.num_heads = num_heads + ctx.entropy_limit = entropy_limit + ctx.grad_scale = grad_scale + return attn_weights + + @staticmethod + def backward(ctx, + attn_weights_grad: Tensor): + attn_weights, = ctx.saved_tensors + num_heads = ctx.num_heads + entropy_limit = ctx.entropy_limit + grad_scale = ctx.grad_scale + logging.info("Here4") + with torch.enable_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights_orig = attn_weights.to(torch.float32).detach() + attn_weights_orig.requires_grad = True + bsz = attn_weights_orig.shape[0] // num_heads + seq_len = attn_weights_orig.shape[2] + attn_weights = attn_weights_orig.reshape(bsz, num_heads, + seq_len, seq_len) + entropy = ((attn_weights + 1.0e-20).log() * attn_weights).sum(dim=-1) + # entropy: (bsz, num_heads, seq_len) + entropy = -entropy.mean(dim=(0,2)) + # entropy: (num_heads,) + assert entropy.shape == (num_heads,) + excess_entropy = (entropy - entropy_limit).relu() + above_cutoff = (excess_entropy != 0) # tensor of shape (num_heads,) + if random.random() < 0.001 or __name__ == "__main__": + logging.info(f"entropy = {entropy}, entropy_limit={entropy_limit}, above_cutoff={above_cutoff}") + above_cutoff_sum = above_cutoff.to(torch.float32).sum() + above_cutoff_sum = above_cutoff_sum.item() + if above_cutoff_sum == 0: + # grad would be 0. I'm guessing that checking this, and + # incurring a CUDA sync, may save time relative to doing the + # backprop of the entropy, but I'm not sure. + return attn_weights_grad, None, None, None + # Treat `excess_entropy` as a loss, to be minimized. + excess_entropy.backward(gradient=torch.ones_like(excess_entropy)) + entropy_grad = attn_weights_orig.grad + scale = ((grad_scale * above_cutoff_sum / num_heads) * + (attn_weights_grad.to(torch.float32).norm() / + (entropy_grad.norm() + 1.0e-20))) + entropy_grad = entropy_grad * scale + return attn_weights_grad + entropy_grad.to(attn_weights_grad.dtype), None, None, None + + + + + +class EntropyPenalty(nn.Module): + def __init__( + self, + num_heads: float, + entropy_delta: float, + prob: float, + grad_scale: float): + """ + Args: + num_heads: the number of attention heads in the self-attention module that + this is attached to. + entropy_delta: the delta from the maximum entropy, that we aim to + decrease the entropy to if it is above. So the maximum entropy + should be max(log(seq_len) - entropy_cutoff, 0.5 * log(seq_len)); + the second term is to make sure the limit never becomes tiny or + negative in the case of short sequences. + prob: the probability with which we apply this object. + grad_scale: determines the scale on the gradient term from this object, + relative to the rest of the gradient on the attention weights; + will be divided by `prob`. + """ + super(EntropyPenalty, self).__init__() + self.num_heads = num_heads + self.entropy_delta = entropy_delta + self.prob = prob + self.grad_scale = grad_scale + + def forward(self, + attn_weights: Tensor) -> Tensor: + """ + In the forward pass, this function just returns the attention weights. + In the backward pass, it will modify the gradients to ensure that the + entropy of the attention heads is not too large. (We have noticed + that too-large/almost-maximal entropy in the attention distribution + is associated with heads that are not doing anything useful. + + Args: + attn_weights: the attention weights, after the log, with shape + (batch_size * num_heads, seq_len, seq_len), satisfying: + attn_weights.sum(dim=-1) == 1. + Returns: + the attn_weights, without any change. You should make sure + you use the returned attention weights, or the graph will be freed + and nothing will happen in backprop. + """ + logging.info("Here1") + if not attn_weights.requires_grad or random.random() > self.prob: + logging.info("Here2") + return attn_weights + else: + seq_len = attn_weights.shape[2] + max_entropy = math.log(seq_len) + entropy_limit = max(max_entropy - self.entropy_delta, + 0.5 * max_entropy) + return EntropyPenaltyFunction.apply(attn_weights, + self.num_heads, + entropy_limit, + self.grad_scale / self.prob) + + + + class RelPositionMultiheadAttention(nn.Module): r"""Multi-Head Attention layer with relative position encoding @@ -851,6 +972,11 @@ class RelPositionMultiheadAttention(nn.Module): self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True, initial_scale=0.05) + self.entropy_penalty = EntropyPenalty(num_heads, + entropy_delta=0.8, + prob=1.0 if __name__ == "__main__" else 0.2, + grad_scale=0.01) + self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads)) self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads)) @@ -1204,6 +1330,8 @@ class RelPositionMultiheadAttention(nn.Module): Returns: output of the same shape as x, i.e. (seq_len, batch_size, embed_dim) """ + attn_weights = self.entropy_penalty(attn_weights) + num_heads = self.num_heads (seq_len, bsz, embed_dim) = x.shape head_dim = embed_dim // (num_heads * 2) @@ -1631,7 +1759,7 @@ def _test_conformer_main(): torch.randn(batch_size, seq_len, feature_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), ) - f # to remove flake8 warnings + f[0].sum().backward() c.eval() f = c( torch.randn(batch_size, seq_len, feature_dim),