mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Add debug code for attention weights and eigs
This commit is contained in:
parent
1825336841
commit
18ff1de337
@ -1212,6 +1212,10 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
v = v.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
|
||||
# now v: (bsz * num_heads, seq_len, head_dim)
|
||||
attn_output = torch.bmm(attn_weights, v)
|
||||
|
||||
if random.random() < 0.001 or __name__ == "__main__":
|
||||
self._print_attn_stats(attn_weights, attn_output)
|
||||
|
||||
# attn_output: (bsz * num_heads, seq_len, head_dim)
|
||||
attn_output = (
|
||||
attn_output.transpose(0, 1)
|
||||
@ -1222,6 +1226,33 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
return self.out_proj2(attn_output)
|
||||
|
||||
|
||||
def _print_attn_stats(
    self,
    attn_weights: Tensor,
    attn_output: Tensor,
) -> None:
    """Log per-head diagnostics of the attention weights and outputs.

    Two statistics are computed, without tracking gradients and with
    autocast disabled, and written in one ``logging.info`` message:

    * ``attn_weights_entropy``: for each head, the entropy
      ``-sum(p * log(p))`` taken over the last axis of ``attn_weights``,
      averaged over the batch and the query positions.
    * ``output_eigs``: for each head, the eigenvalues of the covariance
      matrix of that head's output vectors, pooled over the batch and all
      positions.

    Args:
      attn_weights:
        Tensor of shape (batch_size * num_heads, seq_len, seq_len).
      attn_output:
        Tensor of shape (batch_size * num_heads, seq_len, head_dim).

    Returns:
      None; the statistics are only logged.
    """
    # attn_weights: (batch_size * num_heads, seq_len, seq_len)
    # attn_output: (bsz * num_heads, seq_len, head_dim)
    (n, seq_len, head_dim) = attn_output.shape
    num_heads = self.num_heads
    bsz = n // num_heads

    with torch.no_grad():
        with torch.cuda.amp.autocast(enabled=False):
            # Do the statistics in float32 regardless of the dtype that the
            # surrounding code produced.
            attn_weights = attn_weights.to(torch.float32)
            attn_output = attn_output.to(torch.float32)

            # The 1.0e-20 offset keeps log() finite where a weight is 0.
            attn_weights_entropy = (
                -((attn_weights + 1.0e-20).log() * attn_weights)
                .sum(dim=-1)
                .reshape(bsz, num_heads, seq_len)
                .mean(dim=(0, 2))
            )
            # attn_weights_entropy: (num_heads,)

            # Group all output vectors belonging to the same head.
            attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim)
            attn_output = attn_output.permute(1, 0, 2, 3).reshape(
                num_heads, bsz * seq_len, head_dim
            )
            # Center per head so the product below is a covariance.
            attn_output_mean = attn_output.mean(dim=1, keepdim=True)
            attn_output = attn_output - attn_output_mean
            attn_covar = torch.matmul(
                attn_output.transpose(1, 2), attn_output
            ) / (bsz * seq_len)
            # attn_covar: (num_heads, head_dim, head_dim)

            # torch.symeig is deprecated and missing from current PyTorch;
            # torch.linalg.eigvalsh is its replacement.  UPLO="U" reads the
            # upper triangle, as symeig did by default.  Fall back to symeig
            # only on builds that lack torch.linalg.eigvalsh.
            linalg = getattr(torch, "linalg", None)
            if linalg is not None and hasattr(linalg, "eigvalsh"):
                eigs = linalg.eigvalsh(attn_covar, UPLO="U")
            else:
                eigs, _ = torch.symeig(attn_covar)
            logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}")
|
||||
|
||||
|
||||
class FeedforwardModule(nn.Module):
|
||||
"""Feedforward module in Conformer model.
|
||||
"""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user