Add debug code for attention weihts and eigs

This commit is contained in:
Daniel Povey 2022-10-14 20:57:17 +08:00
parent 1825336841
commit 18ff1de337

View File

@ -1212,6 +1212,10 @@ class RelPositionMultiheadAttention(nn.Module):
v = v.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
# now v: (bsz * num_heads, seq_len, head_dim)
attn_output = torch.bmm(attn_weights, v)
if random.random() < 0.001 or __name__ == "__main__":
self._print_attn_stats(attn_weights, attn_output)
# attn_output: (bsz * num_heads, seq_len, head_dim)
attn_output = (
attn_output.transpose(0, 1)
@ -1222,6 +1226,33 @@ class RelPositionMultiheadAttention(nn.Module):
return self.out_proj2(attn_output)
def _print_attn_stats(
self,
attn_weights: Tensor,
attn_output: Tensor):
# attn_weights: (batch_size * num_heads, seq_len, seq_len)
# attn_output: (bsz * num_heads, seq_len, head_dim)
(n, seq_len, head_dim) = attn_output.shape
num_heads = self.num_heads
bsz = n // num_heads
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False):
attn_weights = attn_weights.to(torch.float32)
xyz
attn_output = attn_output.to(torch.float32)
attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum(
dim=-1).reshape(bsz, num_heads, seq_len).mean(dim=(0,2))
attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim)
attn_output = attn_output.permute(1, 0, 2, 3).reshape(num_heads, bsz * seq_len, head_dim)
attn_output_mean = attn_output.mean(dim=1, keepdim=True)
attn_output = attn_output - attn_output_mean
attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / (bsz * seq_len)
# attn_covar: (num_heads, head_dim, head_dim)
eigs, _ = torch.symeig(attn_covar)
logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}")
class FeedforwardModule(nn.Module):
"""Feedforward module in Conformer model.
"""