From 18ff1de337b5f16191dd2fd46a39e909902922e9 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 14 Oct 2022 20:57:17 +0800
Subject: [PATCH] Add debug code for attention weights and eigs

---
 .../pruned_transducer_stateless7/conformer.py | 31 +++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index cef8d1b18..b04a695cd 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -1212,6 +1212,10 @@ class RelPositionMultiheadAttention(nn.Module):
         v = v.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
         # now v: (bsz * num_heads, seq_len, head_dim)
         attn_output = torch.bmm(attn_weights, v)
+
+        if random.random() < 0.001 or __name__ == "__main__":
+            self._print_attn_stats(attn_weights, attn_output)
+
         # attn_output: (bsz * num_heads, seq_len, head_dim)
         attn_output = (
             attn_output.transpose(0, 1)
@@ -1222,6 +1226,33 @@ class RelPositionMultiheadAttention(nn.Module):
 
         return self.out_proj2(attn_output)
 
+    def _print_attn_stats(
+        self,
+        attn_weights: Tensor,
+        attn_output: Tensor):
+        # attn_weights: (batch_size * num_heads, seq_len, seq_len)
+        # attn_output: (bsz * num_heads, seq_len, head_dim)
+        (n, seq_len, head_dim) = attn_output.shape
+        num_heads = self.num_heads
+        bsz = n // num_heads
+
+        with torch.no_grad():
+            with torch.cuda.amp.autocast(enabled=False):
+                attn_weights = attn_weights.to(torch.float32)
+                attn_output = attn_output.to(torch.float32)
+                attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum(
+                    dim=-1).reshape(bsz, num_heads, seq_len).mean(dim=(0,2))
+                attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim)
+                attn_output = attn_output.permute(1, 0, 2, 3).reshape(num_heads, bsz * seq_len, head_dim)
+                attn_output_mean = attn_output.mean(dim=1, keepdim=True)
+                attn_output = attn_output - attn_output_mean
+                attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / (bsz * seq_len)
+                # attn_covar: (num_heads, head_dim, head_dim)
+                eigs, _ = torch.symeig(attn_covar)
+                logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}")
+
+
 class FeedforwardModule(nn.Module):
     """Feedforward module in Conformer model.
     """
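
For reference, below is a minimal standalone sketch (not part of the patch) of the two per-head diagnostics the debug hook logs: the entropy of each head's attention distribution, averaged over batch and time, and the eigenvalues of the covariance of each head's output. Shapes and variable names follow the patch; the random inputs are placeholders, and torch.linalg.eigvalsh stands in for torch.symeig, which has been removed from recent PyTorch releases.

    import torch

    # Placeholder dimensions; the patch derives these from the real tensors.
    bsz, num_heads, seq_len, head_dim = 2, 8, 10, 64

    # Fake inputs with the same layout as in the patch.
    attn_weights = torch.softmax(
        torch.randn(bsz * num_heads, seq_len, seq_len), dim=-1)
    attn_output = torch.randn(bsz * num_heads, seq_len, head_dim)

    # Entropy of each head's attention distribution, averaged over batch and time.
    attn_weights_entropy = (
        -((attn_weights + 1.0e-20).log() * attn_weights)
        .sum(dim=-1)
        .reshape(bsz, num_heads, seq_len)
        .mean(dim=(0, 2))
    )

    # Per-head covariance of the (mean-subtracted) attention output, and its eigenvalues.
    x = attn_output.reshape(bsz, num_heads, seq_len, head_dim)
    x = x.permute(1, 0, 2, 3).reshape(num_heads, bsz * seq_len, head_dim)
    x = x - x.mean(dim=1, keepdim=True)
    attn_covar = torch.matmul(x.transpose(1, 2), x) / (bsz * seq_len)
    eigs = torch.linalg.eigvalsh(attn_covar)  # (num_heads, head_dim), ascending order

    print("attn_weights_entropy =", attn_weights_entropy)
    print("output_eigs =", eigs)

Low entropy suggests a head is attending very sharply; eigenvalues concentrated in a few directions suggest the head's output occupies a low-dimensional subspace of head_dim.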