From 1812f6cb28984e2439ca854bce0b00236dcdf549 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 14 Oct 2022 21:16:23 +0800
Subject: [PATCH] Add different debug info.

---
 .../ASR/pruned_transducer_stateless7/conformer.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 798c316a5..a4d911e1f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -1248,8 +1248,14 @@ class RelPositionMultiheadAttention(nn.Module):
             attn_output = attn_output - attn_output_mean
             attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / (bsz * seq_len)
             # attn_covar: (num_heads, head_dim, head_dim)
-            eigs, _ = torch.symeig(attn_covar)
-            logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}")
+            #eigs, _ = torch.symeig(attn_covar)
+            #logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}")
+
+            attn_covar = attn_covar.mean(dim=1).sum(dim=1)  # (num_heads,)
+            embed_dim = self.in_proj2.weight.shape[1]
+            in_proj_covar = (self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2).mean(dim=(1,2))
+            out_proj_covar = (self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2).mean(dim=(0,2))
+            logging.info(f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}")
 
 
 class FeedforwardModule(nn.Module):
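
For reference, below is a minimal standalone sketch (not part of the commit) of the per-head statistics the new logging computes. The tensor sizes and the in_proj2 / out_proj2 weight layouts, (num_heads * head_dim, embed_dim) and (embed_dim, num_heads * head_dim), are assumptions inferred from the reshape calls in the diff; the random tensors are stand-ins for real activations and weights.

import torch

# Hypothetical sizes, for illustration only.
num_heads, head_dim, embed_dim = 8, 64, 512
bsz, seq_len = 2, 50

# attn_output: (num_heads, bsz*seq_len, head_dim), with the mean over
# positions subtracted, matching the context lines of the hunk.
attn_output = torch.randn(num_heads, bsz * seq_len, head_dim)
attn_output = attn_output - attn_output.mean(dim=1, keepdim=True)

# Per-head covariance of the attention output:
# (num_heads, head_dim, head_dim), as in the patch.
attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / (bsz * seq_len)

# Collapse each head's covariance to one scalar, as the patch does:
# mean over rows, then sum over columns -> (num_heads,).
covar_per_head = attn_covar.mean(dim=1).sum(dim=1)

# Per-head mean-square of the projection weights; these random matrices
# are stand-ins for self.in_proj2.weight and self.out_proj2.weight.
in_proj_weight = torch.randn(num_heads * head_dim, embed_dim)
out_proj_weight = torch.randn(embed_dim, num_heads * head_dim)
in_proj_covar = (in_proj_weight.reshape(num_heads, head_dim, embed_dim) ** 2).mean(dim=(1, 2))
out_proj_covar = (out_proj_weight.reshape(embed_dim, num_heads, head_dim) ** 2).mean(dim=(0, 2))

print(covar_per_head, in_proj_covar, out_proj_covar)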