Fix debug stats.
This commit is contained in:
parent
252798b6a1
commit
9919a05612
@ -1342,7 +1342,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
#eigs, _ = torch.symeig(attn_covar)
|
||||
#logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}")
|
||||
|
||||
attn_covar = attn_covar.mean(dim=1).sum(dim=1) # (num_heads,)
|
||||
attn_covar = _diag(attn_covar).mean(dim=1) # (num_heads,)
|
||||
embed_dim = self.in_proj2.weight.shape[1]
|
||||
in_proj_covar = (self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2).mean(dim=(1,2))
|
||||
out_proj_covar = (self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2).mean(dim=(0,2))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user