From 18ff1de337b5f16191dd2fd46a39e909902922e9 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 14 Oct 2022 20:57:17 +0800
Subject: [PATCH 1/8] Add debug code for attention weights and eigs

---
 .../pruned_transducer_stateless7/conformer.py | 31 +++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index cef8d1b18..b04a695cd 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -1212,6 +1212,10 @@ class RelPositionMultiheadAttention(nn.Module):
         v = v.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
         # now v: (bsz * num_heads, seq_len, head_dim)
         attn_output = torch.bmm(attn_weights, v)
+
+        if random.random() < 0.001 or __name__ == "__main__":
+            self._print_attn_stats(attn_weights, attn_output)
+
         # attn_output: (bsz * num_heads, seq_len, head_dim)
         attn_output = (
             attn_output.transpose(0, 1)
@@ -1222,6 +1226,33 @@ class RelPositionMultiheadAttention(nn.Module):
 
         return self.out_proj2(attn_output)
 
+    def _print_attn_stats(
+            self,
+            attn_weights: Tensor,
+            attn_output: Tensor):
+        # attn_weights: (batch_size * num_heads, seq_len, seq_len)
+        # attn_output: (bsz * num_heads, seq_len, head_dim)
+        (n, seq_len, head_dim) = attn_output.shape
+        num_heads = self.num_heads
+        bsz = n // num_heads
+
+        with torch.no_grad():
+            with torch.cuda.amp.autocast(enabled=False):
+                attn_weights = attn_weights.to(torch.float32)
+                xyz
+                attn_output = attn_output.to(torch.float32)
+                attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum(
+                    dim=-1).reshape(bsz, num_heads, seq_len).mean(dim=(0,2))
+                attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim)
+                attn_output = attn_output.permute(1, 0, 2, 3).reshape(num_heads, bsz * seq_len, head_dim)
+                attn_output_mean = attn_output.mean(dim=1, keepdim=True)
+                attn_output = attn_output - attn_output_mean
+                attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / (bsz * seq_len)
+                # attn_covar: (num_heads, head_dim, head_dim)
+                eigs, _ = torch.symeig(attn_covar)
+                logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}")
+
+
 class FeedforwardModule(nn.Module):
     """Feedforward module in Conformer model.
     """

From 90953537adb2e5427c799d3a859026368f118b8d Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 14 Oct 2022 20:59:26 +0800
Subject: [PATCH 2/8] Remove debug statement

---
 egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index b04a695cd..798c316a5 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -1239,7 +1239,6 @@ class RelPositionMultiheadAttention(nn.Module):
         with torch.no_grad():
             with torch.cuda.amp.autocast(enabled=False):
                 attn_weights = attn_weights.to(torch.float32)
-                xyz
                 attn_output = attn_output.to(torch.float32)
                 attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum(
                     dim=-1).reshape(bsz, num_heads, seq_len).mean(dim=(0,2))
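A note on what patches 1-2 are measuring: _print_attn_stats logs, per attention head, the entropy of the attention distribution (averaged over the batch and over query positions) and the eigenvalues of the covariance of that head's output. Near-maximal entropy together with tiny eigenvalues suggests a head that attends almost uniformly and contributes little. The sketch below reproduces the two statistics outside the module; it is a hypothetical helper, not part of the patches, and it uses torch.linalg.eigvalsh since torch.symeig is deprecated in recent PyTorch releases:

    import torch

    def attn_diagnostics(attn_weights: torch.Tensor,   # (bsz * num_heads, seq_len, seq_len)
                         attn_output: torch.Tensor,    # (bsz * num_heads, seq_len, head_dim)
                         num_heads: int):
        n, seq_len, head_dim = attn_output.shape
        bsz = n // num_heads
        w = attn_weights.to(torch.float32)
        # Per-head entropy of the attention distribution, in nats.
        entropy = -((w + 1.0e-20).log() * w).sum(dim=-1)
        entropy = entropy.reshape(bsz, num_heads, seq_len).mean(dim=(0, 2))
        # Per-head covariance of the attention output, and its eigenvalues.
        out = attn_output.to(torch.float32).reshape(bsz, num_heads, seq_len, head_dim)
        out = out.permute(1, 0, 2, 3).reshape(num_heads, bsz * seq_len, head_dim)
        out = out - out.mean(dim=1, keepdim=True)
        covar = out.transpose(1, 2) @ out / (bsz * seq_len)  # (num_heads, head_dim, head_dim)
        eigs = torch.linalg.eigvalsh(covar)                  # ascending, per head
        return entropy, eigs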
From 1812f6cb28984e2439ca854bce0b00236dcdf549 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 14 Oct 2022 21:16:23 +0800
Subject: [PATCH 3/8] Add different debug info.

---
 .../ASR/pruned_transducer_stateless7/conformer.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 798c316a5..a4d911e1f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -1248,8 +1248,14 @@ class RelPositionMultiheadAttention(nn.Module):
                 attn_output = attn_output - attn_output_mean
                 attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / (bsz * seq_len)
                 # attn_covar: (num_heads, head_dim, head_dim)
-                eigs, _ = torch.symeig(attn_covar)
-                logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}")
+                #eigs, _ = torch.symeig(attn_covar)
+                #logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}")
+
+                attn_covar = attn_covar.mean(dim=1).sum(dim=1)  # (num_heads,)
+                embed_dim = self.in_proj2.weight.shape[1]
+                in_proj_covar = (self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2).mean(dim=(1,2))
+                out_proj_covar = (self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2).mean(dim=(0,2))
+                logging.info(f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}")
 
 
 class FeedforwardModule(nn.Module):
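Patch 3 swaps the eigenvalue decomposition for cheaper per-head scale statistics: a scalar summary of each head's output covariance (its entries averaged over one axis and summed over the other), plus the mean squared input- and output-projection weights for that head, which together indicate whether a head's output or its projections have collapsed toward zero. Written out on its own, and assuming the same weight layouts as the reshapes in the patch, the computation is roughly:

    import torch

    def per_head_scale_stats(attn_covar: torch.Tensor,       # (num_heads, head_dim, head_dim)
                             in_proj_weight: torch.Tensor,   # (num_heads * head_dim, embed_dim)
                             out_proj_weight: torch.Tensor,  # (embed_dim, num_heads * head_dim)
                             num_heads: int, head_dim: int):
        embed_dim = in_proj_weight.shape[1]
        covar_scale = attn_covar.mean(dim=1).sum(dim=1)  # (num_heads,)
        in_proj_var = (in_proj_weight.reshape(num_heads, head_dim, embed_dim) ** 2).mean(dim=(1, 2))
        out_proj_var = (out_proj_weight.reshape(embed_dim, num_heads, head_dim) ** 2).mean(dim=(0, 2))
        return covar_scale, in_proj_var, out_proj_var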
From a780984e6b2338282b70d14eccc55f18c6767d8f Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 14 Oct 2022 23:01:30 +0800
Subject: [PATCH 4/8] Penalize attention-weight entropies above a limit.

---
 .../pruned_transducer_stateless7/conformer.py | 130 +++++++++++++++++-
 1 file changed, 129 insertions(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index a4d911e1f..90449617d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -805,6 +805,127 @@ class RelPositionalEncoding(torch.nn.Module):
         return self.dropout(pos_emb)
 
 
+class EntropyPenaltyFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx,
+                attn_weights: Tensor,
+                num_heads: int,
+                entropy_limit: float,
+                grad_scale: float) -> Tensor:
+        logging.info("Here3")
+        ctx.save_for_backward(attn_weights)
+        ctx.num_heads = num_heads
+        ctx.entropy_limit = entropy_limit
+        ctx.grad_scale = grad_scale
+        return attn_weights
+
+    @staticmethod
+    def backward(ctx,
+                 attn_weights_grad: Tensor):
+        attn_weights, = ctx.saved_tensors
+        num_heads = ctx.num_heads
+        entropy_limit = ctx.entropy_limit
+        grad_scale = ctx.grad_scale
+        logging.info("Here4")
+        with torch.enable_grad():
+            with torch.cuda.amp.autocast(enabled=False):
+                attn_weights_orig = attn_weights.to(torch.float32).detach()
+                attn_weights_orig.requires_grad = True
+                bsz = attn_weights_orig.shape[0] // num_heads
+                seq_len = attn_weights_orig.shape[2]
+                attn_weights = attn_weights_orig.reshape(bsz, num_heads,
+                                                         seq_len, seq_len)
+                entropy = ((attn_weights + 1.0e-20).log() * attn_weights).sum(dim=-1)
+                # entropy: (bsz, num_heads, seq_len)
+                entropy = -entropy.mean(dim=(0,2))
+                # entropy: (num_heads,)
+                assert entropy.shape == (num_heads,)
+                excess_entropy = (entropy - entropy_limit).relu()
+                above_cutoff = (excess_entropy != 0)  # tensor of shape (num_heads,)
+                if random.random() < 0.001 or __name__ == "__main__":
+                    logging.info(f"entropy = {entropy}, entropy_limit={entropy_limit}, above_cutoff={above_cutoff}")
+                above_cutoff_sum = above_cutoff.to(torch.float32).sum()
+                above_cutoff_sum = above_cutoff_sum.item()
+                if above_cutoff_sum == 0:
+                    # grad would be 0. I'm guessing that checking this, and
+                    # incurring a CUDA sync, may save time relative to doing the
+                    # backprop of the entropy, but I'm not sure.
+                    return attn_weights_grad, None, None, None
+                # Treat `excess_entropy` as a loss, to be minimized.
+                excess_entropy.backward(gradient=torch.ones_like(excess_entropy))
+                entropy_grad = attn_weights_orig.grad
+                scale = ((grad_scale * above_cutoff_sum / num_heads) *
+                         (attn_weights_grad.to(torch.float32).norm() /
+                          (entropy_grad.norm() + 1.0e-20)))
+                entropy_grad = entropy_grad * scale
+                return attn_weights_grad + entropy_grad.to(attn_weights_grad.dtype), None, None, None
+
+
+
+
+
+class EntropyPenalty(nn.Module):
+    def __init__(
+            self,
+            num_heads: float,
+            entropy_delta: float,
+            prob: float,
+            grad_scale: float):
+        """
+        Args:
+          num_heads: the number of attention heads in the self-attention module that
+              this is attached to.
+          entropy_delta: the delta below the maximum possible entropy that we aim
+              to decrease the entropy to, if it is above that limit.  So the entropy
+              limit will be max(log(seq_len) - entropy_delta, 0.5 * log(seq_len));
+              the second term is to make sure the limit never becomes tiny or
+              negative in the case of short sequences.
+          prob: the probability with which we apply this object.
+          grad_scale: determines the scale on the gradient term from this object,
+              relative to the rest of the gradient on the attention weights;
+              will be divided by `prob`.
+        """
+        super(EntropyPenalty, self).__init__()
+        self.num_heads = num_heads
+        self.entropy_delta = entropy_delta
+        self.prob = prob
+        self.grad_scale = grad_scale
+
+    def forward(self,
+                attn_weights: Tensor) -> Tensor:
+        """
+        In the forward pass, this function just returns the attention weights.
+        In the backward pass, it will modify the gradients to ensure that the
+        entropy of the attention heads is not too large.  (We have noticed
+        that too-large/almost-maximal entropy in the attention distribution
+        is associated with heads that are not doing anything useful.)
+
+        Args:
+           attn_weights: the attention weights, after the softmax, with shape
+              (batch_size * num_heads, seq_len, seq_len), satisfying:
+              attn_weights.sum(dim=-1) == 1.
+        Returns:
+           the attn_weights, without any change.  You should make sure
+           you use the returned attention weights, or the graph will be freed
+           and nothing will happen in backprop.
+        """
+        logging.info("Here1")
+        if not attn_weights.requires_grad or random.random() > self.prob:
+            logging.info("Here2")
+            return attn_weights
+        else:
+            seq_len = attn_weights.shape[2]
+            max_entropy = math.log(seq_len)
+            entropy_limit = max(max_entropy - self.entropy_delta,
+                                0.5 * max_entropy)
+            return EntropyPenaltyFunction.apply(attn_weights,
+                                                self.num_heads,
+                                                entropy_limit,
+                                                self.grad_scale / self.prob)
+
+
+
+
 class RelPositionMultiheadAttention(nn.Module):
     r"""Multi-Head Attention layer with relative position encoding
 
@@ -851,6 +972,11 @@ class RelPositionMultiheadAttention(nn.Module):
         self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True,
                                       initial_scale=0.05)
 
+        self.entropy_penalty = EntropyPenalty(num_heads,
+                                              entropy_delta=0.8,
+                                              prob=1.0 if __name__ == "__main__" else 0.2,
+                                              grad_scale=0.01)
+
         self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
         self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads))
 
@@ -1204,6 +1330,8 @@ class RelPositionMultiheadAttention(nn.Module):
         Returns:
            output of the same shape as x, i.e. (seq_len, batch_size, embed_dim)
         """
+        attn_weights = self.entropy_penalty(attn_weights)
+
         num_heads = self.num_heads
         (seq_len, bsz, embed_dim) = x.shape
         head_dim = embed_dim // (num_heads * 2)
@@ -1631,7 +1759,7 @@ def _test_conformer_main():
         torch.randn(batch_size, seq_len, feature_dim),
         torch.full((batch_size,), seq_len, dtype=torch.int64),
     )
-    f  # to remove flake8 warnings
+    f[0].sum().backward()
     c.eval()
     f = c(
         torch.randn(batch_size, seq_len, feature_dim),
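The core of patch 4 is a gradient-only regularizer: EntropyPenaltyFunction is the identity in the forward pass, and in the backward pass it recomputes the per-head attention entropy inside torch.enable_grad(), backprops the excess over the limit through that small local graph, rescales the resulting gradient to a fixed fraction of the incoming gradient's norm, and adds it to the incoming gradient. A stripped-down sketch of the same pattern (no per-head bookkeeping or head-count scaling; the names are mine, not the icefall API):

    import torch

    class EntropyPenaltySketch(torch.autograd.Function):
        """Identity in forward; in backward, adds a gradient term that pushes the
        attention entropy down whenever it exceeds `entropy_limit`."""

        @staticmethod
        def forward(ctx, attn_weights, entropy_limit: float, grad_scale: float):
            ctx.save_for_backward(attn_weights)
            ctx.entropy_limit = entropy_limit
            ctx.grad_scale = grad_scale
            return attn_weights

        @staticmethod
        def backward(ctx, grad_out):
            (attn_weights,) = ctx.saved_tensors
            with torch.enable_grad():
                w = attn_weights.detach().float().requires_grad_(True)
                entropy = -((w + 1.0e-20).log() * w).sum(dim=-1).mean()
                excess = (entropy - ctx.entropy_limit).relu()
                if excess.item() == 0.0:
                    return grad_out, None, None
                (entropy_grad,) = torch.autograd.grad(excess, w)
            # Keep the penalty a fixed fraction of the incoming gradient's norm.
            scale = ctx.grad_scale * grad_out.float().norm() / (entropy_grad.norm() + 1.0e-20)
            return grad_out + (entropy_grad * scale).to(grad_out.dtype), None, None

Using torch.autograd.grad on a detached leaf, as in this sketch, is interchangeable with the .backward()-then-.grad pattern the patch uses; both stay local to the penalty computation and leave the main gradient untouched when the entropy is under the limit.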
+ """ + logging.info("Here1") + if not attn_weights.requires_grad or random.random() > self.prob: + logging.info("Here2") + return attn_weights + else: + seq_len = attn_weights.shape[2] + max_entropy = math.log(seq_len) + entropy_limit = max(max_entropy - self.entropy_delta, + 0.5 * max_entropy) + return EntropyPenaltyFunction.apply(attn_weights, + self.num_heads, + entropy_limit, + self.grad_scale / self.prob) + + + + class RelPositionMultiheadAttention(nn.Module): r"""Multi-Head Attention layer with relative position encoding @@ -851,6 +972,11 @@ class RelPositionMultiheadAttention(nn.Module): self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True, initial_scale=0.05) + self.entropy_penalty = EntropyPenalty(num_heads, + entropy_delta=0.8, + prob=1.0 if __name__ == "__main__" else 0.2, + grad_scale=0.01) + self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads)) self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads)) @@ -1204,6 +1330,8 @@ class RelPositionMultiheadAttention(nn.Module): Returns: output of the same shape as x, i.e. (seq_len, batch_size, embed_dim) """ + attn_weights = self.entropy_penalty(attn_weights) + num_heads = self.num_heads (seq_len, bsz, embed_dim) = x.shape head_dim = embed_dim // (num_heads * 2) @@ -1631,7 +1759,7 @@ def _test_conformer_main(): torch.randn(batch_size, seq_len, feature_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), ) - f # to remove flake8 warnings + f[0].sum().backward() c.eval() f = c( torch.randn(batch_size, seq_len, feature_dim), From 394d4c95f901e1fa1435a991adcc888f4e289cf3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 14 Oct 2022 23:09:05 +0800 Subject: [PATCH 5/8] Remove debug statements --- egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 90449617d..673470f59 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -812,7 +812,6 @@ class EntropyPenaltyFunction(torch.autograd.Function): num_heads: int, entropy_limit: float, grad_scale: float) -> Tensor: - logging.info("Here3") ctx.save_for_backward(attn_weights) ctx.num_heads = num_heads ctx.entropy_limit = entropy_limit @@ -826,7 +825,6 @@ class EntropyPenaltyFunction(torch.autograd.Function): num_heads = ctx.num_heads entropy_limit = ctx.entropy_limit grad_scale = ctx.grad_scale - logging.info("Here4") with torch.enable_grad(): with torch.cuda.amp.autocast(enabled=False): attn_weights_orig = attn_weights.to(torch.float32).detach() @@ -909,9 +907,7 @@ class EntropyPenalty(nn.Module): you use the returned attention weights, or the graph will be freed and nothing will happen in backprop. 
""" - logging.info("Here1") if not attn_weights.requires_grad or random.random() > self.prob: - logging.info("Here2") return attn_weights else: seq_len = attn_weights.shape[2] From 0557dbb72076465a4edb35250d5f316a11331c6b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 14 Oct 2022 23:23:20 +0800 Subject: [PATCH 6/8] use larger delta but only penalize if small grad norm --- .../pruned_transducer_stateless7/conformer.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 673470f59..a0f7ade04 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -833,26 +833,32 @@ class EntropyPenaltyFunction(torch.autograd.Function): seq_len = attn_weights_orig.shape[2] attn_weights = attn_weights_orig.reshape(bsz, num_heads, seq_len, seq_len) + + grad_norms = attn_weights_grad.detach().reshape( + bsz, num_heads, seq_len * seq_len).norm(dim=(0,2)) + entropy = ((attn_weights + 1.0e-20).log() * attn_weights).sum(dim=-1) # entropy: (bsz, num_heads, seq_len) entropy = -entropy.mean(dim=(0,2)) # entropy: (num_heads,) assert entropy.shape == (num_heads,) excess_entropy = (entropy - entropy_limit).relu() - above_cutoff = (excess_entropy != 0) # tensor of shape (num_heads,) + above_cutoff = (entropy > 0) # tensor of shape (num_heads,) + small_grad_norm = (grad_norms < 0.5 * grad_norms.mean()) + will_penalize = torch.logical_and(above_cutoff, small_grad_norm) if random.random() < 0.001 or __name__ == "__main__": - logging.info(f"entropy = {entropy}, entropy_limit={entropy_limit}, above_cutoff={above_cutoff}") - above_cutoff_sum = above_cutoff.to(torch.float32).sum() - above_cutoff_sum = above_cutoff_sum.item() - if above_cutoff_sum == 0: + logging.info(f"entropy = {entropy}, entropy_limit={entropy_limit}, above_cutoff={above_cutoff}, small_grad_norm={small_grad_norm}, will_penalize={will_penalize}") + will_penalize_sum = will_penalize.to(torch.float32).sum() + will_penalize_sum = will_penalize_sum.item() + if will_penalize_sum == 0: # grad would be 0. I'm guessing that checking this, and # incurring a CUDA sync, may save time relative to doing the # backprop of the entropy, but I'm not sure. return attn_weights_grad, None, None, None # Treat `excess_entropy` as a loss, to be minimized. 
From 822465f73bf8f6a931ce524a840127dd07144328 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 14 Oct 2022 23:25:29 +0800
Subject: [PATCH 7/8] Bug fixes; change debug freq

---
 .../ASR/pruned_transducer_stateless7/conformer.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index a0f7ade04..6183a50d4 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -846,17 +846,16 @@ class EntropyPenaltyFunction(torch.autograd.Function):
                 above_cutoff = (entropy > 0)  # tensor of shape (num_heads,)
                 small_grad_norm = (grad_norms < 0.5 * grad_norms.mean())
                 will_penalize = torch.logical_and(above_cutoff, small_grad_norm)
-                if random.random() < 0.001 or __name__ == "__main__":
+                if random.random() < 0.005 or __name__ == "__main__":
                     logging.info(f"entropy = {entropy}, entropy_limit={entropy_limit}, above_cutoff={above_cutoff}, small_grad_norm={small_grad_norm}, will_penalize={will_penalize}")
-                will_penalize_sum = will_penalize.to(torch.float32).sum()
-                will_penalize_sum = will_penalize_sum.item()
+                will_penalize_sum = will_penalize.to(torch.float32).sum().item()
                 if will_penalize_sum == 0:
                     # grad would be 0. I'm guessing that checking this, and
                     # incurring a CUDA sync, may save time relative to doing the
                     # backprop of the entropy, but I'm not sure.
                     return attn_weights_grad, None, None, None
                 # Treat `excess_entropy` as a loss, to be minimized.
-                excess_entropy.backward(gradient=will_penalize_sum.to(torch.float32))
+                excess_entropy.backward(gradient=will_penalize.to(torch.float32))
                 entropy_grad = attn_weights_orig.grad
                 scale = ((grad_scale * will_penalize_sum / num_heads) *
                          (attn_weights_grad.to(torch.float32).norm() /
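The "bug fixes" in patch 7 are worth spelling out. After patch 6, will_penalize_sum was a Python float (because of .item()), so will_penalize_sum.to(torch.float32) in the backward() call would have raised an AttributeError; passing the per-head mask will_penalize instead both fixes that and does the masking: excess_entropy has shape (num_heads,), backward(gradient=...) computes a vector-Jacobian product, and heads whose mask entry is 0 contribute no penalty gradient at all. A toy illustration of that masking, with made-up shapes:

    import torch

    # Per-"head" losses; backward(gradient=mask) backprops only the selected entries.
    w = torch.randn(4, 5, requires_grad=True)   # stand-in for per-head quantities
    loss_per_head = (w ** 2).mean(dim=1)        # shape (4,), like excess_entropy
    mask = torch.tensor([1.0, 0.0, 1.0, 0.0])   # like will_penalize.to(torch.float32)
    loss_per_head.backward(gradient=mask)
    assert w.grad[1].abs().sum() == 0 and w.grad[3].abs().sum() == 0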
- excess_entropy.backward(gradient=will_penalize_sum.to(torch.float32)) + excess_entropy.backward(gradient=will_penalize.to(torch.float32)) entropy_grad = attn_weights_orig.grad scale = ((grad_scale * will_penalize_sum / num_heads) * (attn_weights_grad.to(torch.float32).norm() / From 80d51efd154520a38e77f0568b4481c3515a15b9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 14 Oct 2022 23:29:55 +0800 Subject: [PATCH 8/8] Change cutoff for small_grad_norm --- egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 6183a50d4..c8de95cab 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -844,7 +844,7 @@ class EntropyPenaltyFunction(torch.autograd.Function): assert entropy.shape == (num_heads,) excess_entropy = (entropy - entropy_limit).relu() above_cutoff = (entropy > 0) # tensor of shape (num_heads,) - small_grad_norm = (grad_norms < 0.5 * grad_norms.mean()) + small_grad_norm = (grad_norms < grad_norms.mean()) will_penalize = torch.logical_and(above_cutoff, small_grad_norm) if random.random() < 0.005 or __name__ == "__main__": logging.info(f"entropy = {entropy}, entropy_limit={entropy_limit}, above_cutoff={above_cutoff}, small_grad_norm={small_grad_norm}, will_penalize={will_penalize}")