Use normal implementation of softmax.

This commit is contained in:
Daniel Povey 2022-10-20 19:34:10 +08:00
parent 6e6209419c
commit 1018a77410

View File

@ -1171,7 +1171,7 @@ class RelPositionMultiheadAttention(nn.Module):
bsz * num_heads, seq_len, seq_len
)
attn_output_weights = softmax(attn_output_weights, dim=-1)
attn_output_weights = attn_output_weights.softmax(dim=-1)
attn_output_weights = nn.functional.dropout(
attn_output_weights, p=dropout_p, training=training
)
@ -1584,7 +1584,7 @@ class AttentionCombine(nn.Module):
single_prob_mask)
weights = weights.masked_fill(mask, float('-inf'))
weights = softmax(weights, dim=1)
weights = weights.softmax(dim=1)
# (num_frames, num_channels, num_inputs) * (num_frames, num_inputs, 1) -> (num_frames, num_channels, 1),
ans = torch.matmul(stacked_inputs, weights.unsqueeze(2))