Mirror of https://github.com/k2-fsa/icefall.git
Use normal implementation of softmax.

parent 6e6209419c
commit 1018a77410
@@ -1171,7 +1171,7 @@ class RelPositionMultiheadAttention(nn.Module):
                 bsz * num_heads, seq_len, seq_len
             )
 
-        attn_output_weights = softmax(attn_output_weights, dim=-1)
+        attn_output_weights = attn_output_weights.softmax(dim=-1)
         attn_output_weights = nn.functional.dropout(
             attn_output_weights, p=dropout_p, training=training
         )
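
Note: in this first hunk, the free-standing softmax() call (a helper whose definition is not shown in this diff) is replaced by PyTorch's built-in Tensor.softmax method. A minimal sketch of the new code path, using made-up shapes for illustration only:

import torch

# Hypothetical sizes; the real tensor is shaped (bsz * num_heads, seq_len, seq_len).
bsz, num_heads, seq_len = 2, 4, 5
attn_output_weights = torch.randn(bsz * num_heads, seq_len, seq_len)

# Tensor.softmax normalizes over the last dimension, so each row of
# attention weights sums to 1.
probs = attn_output_weights.softmax(dim=-1)
assert torch.allclose(probs.sum(dim=-1), torch.ones(bsz * num_heads, seq_len))

# Equivalent functional form from torch.nn.functional.
probs_f = torch.nn.functional.softmax(attn_output_weights, dim=-1)
assert torch.allclose(probs, probs_f)
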
@@ -1584,7 +1584,7 @@ class AttentionCombine(nn.Module):
                                     single_prob_mask)
 
             weights = weights.masked_fill(mask, float('-inf'))
-        weights = softmax(weights, dim=1)
+        weights = weights.softmax(dim=1)
 
         # (num_frames, num_channels, num_inputs) * (num_frames, num_inputs, 1) -> (num_frames, num_channels, 1),
         ans = torch.matmul(stacked_inputs, weights.unsqueeze(2))
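
The second hunk makes the same substitution inside AttentionCombine. A small sketch, with invented shapes and a hypothetical mask, of how the masked softmax and the matmul below it combine the stacked inputs (variable names mirror the diff; everything else is illustrative):

import torch

# Illustrative sizes only.
num_frames, num_channels, num_inputs = 3, 8, 4

stacked_inputs = torch.randn(num_frames, num_channels, num_inputs)
weights = torch.randn(num_frames, num_inputs)

# Hypothetical mask: suppose the last input is masked out for every frame.
mask = torch.zeros(num_frames, num_inputs, dtype=torch.bool)
mask[:, -1] = True

# Entries set to -inf receive exactly zero probability after softmax.
weights = weights.masked_fill(mask, float('-inf'))
weights = weights.softmax(dim=1)
assert torch.all(weights[:, -1] == 0)

# (num_frames, num_channels, num_inputs) @ (num_frames, num_inputs, 1)
# -> (num_frames, num_channels, 1): a per-frame weighted sum over the inputs.
ans = torch.matmul(stacked_inputs, weights.unsqueeze(2))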