From 9919a056125ebb9c310b06e03b256be04b3e9457 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 15 Oct 2022 16:47:46 +0800
Subject: [PATCH] Fix debug stats.

---
 egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index e7b5df4ef..0093b644c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -1342,7 +1342,7 @@ class RelPositionMultiheadAttention(nn.Module):
             #eigs, _ = torch.symeig(attn_covar)
             #logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}")
 
-            attn_covar = attn_covar.mean(dim=1).sum(dim=1) # (num_heads,)
+            attn_covar = _diag(attn_covar).mean(dim=1) # (num_heads,)
             embed_dim = self.in_proj2.weight.shape[1]
             in_proj_covar = (self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2).mean(dim=(1,2))
             out_proj_covar = (self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2).mean(dim=(0,2))
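
Note (not part of the patch): a minimal standalone sketch of the debug statistic this hunk fixes. It assumes `attn_covar` has shape (num_heads, head_dim, head_dim) and that `_diag`, a helper defined elsewhere in conformer.py, returns the diagonal over the last two dimensions; the stand-in `_diag` below is a hypothetical reimplementation under that assumption, not the file's actual code.

import torch

def _diag(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the conformer.py helper: diagonal of the
    # trailing two dims, (..., d, d) -> (..., d).
    return x.diagonal(dim1=-2, dim2=-1)

num_heads, head_dim = 8, 64
attn_covar = torch.randn(num_heads, head_dim, head_dim)
# Make it symmetric positive semi-definite, like a real covariance.
attn_covar = attn_covar @ attn_covar.transpose(-2, -1)

# After the patch: mean of the per-head diagonal (i.e. mean variance
# per head), giving a tensor of shape (num_heads,) as the comment says.
attn_covar_stat = _diag(attn_covar).mean(dim=1)
print(attn_covar_stat.shape)  # torch.Size([8])

The replaced expression, attn_covar.mean(dim=1).sum(dim=1), averaged over one covariance axis and summed over the other instead of reading off the diagonal, so the logged per-head statistic was not the intended mean variance; hence the subject line "Fix debug stats."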