diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 10527a7a5..a944597b0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -801,6 +801,128 @@ class RelPositionalEncoding(torch.nn.Module): return self.dropout(pos_emb) +class EntropyPenaltyFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, + attn_weights: Tensor, + num_heads: int, + entropy_limit: float, + grad_scale: float) -> Tensor: + ctx.save_for_backward(attn_weights) + ctx.num_heads = num_heads + ctx.entropy_limit = entropy_limit + ctx.grad_scale = grad_scale + return attn_weights + + @staticmethod + def backward(ctx, + attn_weights_grad: Tensor): + attn_weights, = ctx.saved_tensors + num_heads = ctx.num_heads + entropy_limit = ctx.entropy_limit + grad_scale = ctx.grad_scale + with torch.enable_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights_orig = attn_weights.to(torch.float32).detach() + attn_weights_orig.requires_grad = True + bsz = attn_weights_orig.shape[0] // num_heads + seq_len = attn_weights_orig.shape[2] + attn_weights = attn_weights_orig.reshape(bsz, num_heads, + seq_len, seq_len) + + grad_norms = attn_weights_grad.detach().reshape( + bsz, num_heads, seq_len * seq_len).norm(dim=(0,2)) + + entropy = ((attn_weights + 1.0e-20).log() * attn_weights).sum(dim=-1) + # entropy: (bsz, num_heads, seq_len) + entropy = -entropy.mean(dim=(0,2)) + # entropy: (num_heads,) + assert entropy.shape == (num_heads,) + excess_entropy = (entropy - entropy_limit).relu() + above_cutoff = (entropy > 0) # tensor of shape (num_heads,) + small_grad_norm = (grad_norms < grad_norms.mean()) + will_penalize = torch.logical_and(above_cutoff, small_grad_norm) + if random.random() < 0.005 or __name__ == "__main__": + logging.info(f"entropy = {entropy}, entropy_limit={entropy_limit}, above_cutoff={above_cutoff}, small_grad_norm={small_grad_norm}, will_penalize={will_penalize}") + will_penalize_sum = will_penalize.to(torch.float32).sum().item() + if will_penalize_sum == 0: + # grad would be 0. I'm guessing that checking this, and + # incurring a CUDA sync, may save time relative to doing the + # backprop of the entropy, but I'm not sure. + return attn_weights_grad, None, None, None + # Treat `excess_entropy` as a loss, to be minimized. + excess_entropy.backward(gradient=will_penalize.to(torch.float32)) + entropy_grad = attn_weights_orig.grad + scale = ((grad_scale * will_penalize_sum / num_heads) * + (attn_weights_grad.to(torch.float32).norm() / + (entropy_grad.norm() + 1.0e-20))) + entropy_grad = entropy_grad * scale + return attn_weights_grad + entropy_grad.to(attn_weights_grad.dtype), None, None, None + + + + + +class EntropyPenalty(nn.Module): + def __init__( + self, + num_heads: float, + entropy_delta: float, + prob: float, + grad_scale: float): + """ + Args: + num_heads: the number of attention heads in the self-attention module that + this is attached to. + entropy_delta: the delta from the maximum entropy, that we aim to + decrease the entropy to if it is above. So the maximum entropy + should be max(log(seq_len) - entropy_cutoff, 0.5 * log(seq_len)); + the second term is to make sure the limit never becomes tiny or + negative in the case of short sequences. + prob: the probability with which we apply this object. + grad_scale: determines the scale on the gradient term from this object, + relative to the rest of the gradient on the attention weights; + will be divided by `prob`. + """ + super(EntropyPenalty, self).__init__() + self.num_heads = num_heads + self.entropy_delta = entropy_delta + self.prob = prob + self.grad_scale = grad_scale + + def forward(self, + attn_weights: Tensor) -> Tensor: + """ + In the forward pass, this function just returns the attention weights. + In the backward pass, it will modify the gradients to ensure that the + entropy of the attention heads is not too large. (We have noticed + that too-large/almost-maximal entropy in the attention distribution + is associated with heads that are not doing anything useful. + + Args: + attn_weights: the attention weights, after the log, with shape + (batch_size * num_heads, seq_len, seq_len), satisfying: + attn_weights.sum(dim=-1) == 1. + Returns: + the attn_weights, without any change. You should make sure + you use the returned attention weights, or the graph will be freed + and nothing will happen in backprop. + """ + if not attn_weights.requires_grad or random.random() > self.prob: + return attn_weights + else: + seq_len = attn_weights.shape[2] + max_entropy = math.log(seq_len) + entropy_limit = max(max_entropy - self.entropy_delta, + 0.5 * max_entropy) + return EntropyPenaltyFunction.apply(attn_weights, + self.num_heads, + entropy_limit, + self.grad_scale / self.prob) + + + + class RelPositionMultiheadAttention(nn.Module): r"""Multi-Head Attention layer with relative position encoding @@ -848,6 +970,16 @@ class RelPositionMultiheadAttention(nn.Module): # linear transformation for positional encoding (projects to a scalar per head, # which will be added to the score). self.linear_pos = ScaledLinear(embed_dim, num_heads, initial_scale=0.05) + self.entropy_penalty = EntropyPenalty(num_heads, + entropy_delta=1.5, + prob=1.0 if __name__ == "__main__" else 0.2, + grad_scale=0.01) + + self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads)) + self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads)) + + # linear transformation for positional encoding. + self.linear_pos = nn.Linear(embed_dim, num_heads, bias=False) def forward( self, @@ -1146,6 +1278,8 @@ class RelPositionMultiheadAttention(nn.Module): Returns: output of the same shape as x, i.e. (seq_len, batch_size, embed_dim) """ + attn_weights = self.entropy_penalty(attn_weights) + num_heads = self.num_heads (seq_len, bsz, embed_dim) = x.shape head_dim = embed_dim // (num_heads * 2) @@ -1154,6 +1288,10 @@ class RelPositionMultiheadAttention(nn.Module): v = v.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1) # now v: (bsz * num_heads, seq_len, head_dim) attn_output = torch.bmm(attn_weights, v) + + if random.random() < 0.001 or __name__ == "__main__": + self._print_attn_stats(attn_weights, attn_output) + # attn_output: (bsz * num_heads, seq_len, head_dim) attn_output = ( attn_output.transpose(0, 1) @@ -1164,6 +1302,38 @@ class RelPositionMultiheadAttention(nn.Module): return self.out_proj2(attn_output) + def _print_attn_stats( + self, + attn_weights: Tensor, + attn_output: Tensor): + # attn_weights: (batch_size * num_heads, seq_len, seq_len) + # attn_output: (bsz * num_heads, seq_len, head_dim) + (n, seq_len, head_dim) = attn_output.shape + num_heads = self.num_heads + bsz = n // num_heads + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_output = attn_output.to(torch.float32) + attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum( + dim=-1).reshape(bsz, num_heads, seq_len).mean(dim=(0,2)) + attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim) + attn_output = attn_output.permute(1, 0, 2, 3).reshape(num_heads, bsz * seq_len, head_dim) + attn_output_mean = attn_output.mean(dim=1, keepdim=True) + attn_output = attn_output - attn_output_mean + attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / (bsz * seq_len) + # attn_covar: (num_heads, head_dim, head_dim) + #eigs, _ = torch.symeig(attn_covar) + #logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") + + attn_covar = attn_covar.mean(dim=1).sum(dim=1) # (num_heads,) + embed_dim = self.in_proj2.weight.shape[1] + in_proj_covar = (self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2).mean(dim=(1,2)) + out_proj_covar = (self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2).mean(dim=(0,2)) + logging.info(f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}") + + class FeedforwardModule(nn.Module): """Feedforward module in Conformer model. """ @@ -1537,7 +1707,7 @@ def _test_conformer_main(): torch.randn(batch_size, seq_len, feature_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), ) - f # to remove flake8 warnings + f[0].sum().backward() c.eval() f = c( torch.randn(batch_size, seq_len, feature_dim),