diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 9e2346922..a84bc3d95 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -1171,7 +1171,7 @@ class RelPositionMultiheadAttention(nn.Module):
             bsz * num_heads, seq_len, seq_len
         )

-        attn_output_weights = softmax(attn_output_weights, dim=-1)
+        attn_output_weights = attn_output_weights.softmax(dim=-1)
         attn_output_weights = nn.functional.dropout(
             attn_output_weights, p=dropout_p, training=training
         )
@@ -1584,7 +1584,7 @@ class AttentionCombine(nn.Module):
                                  single_prob_mask)
        weights = weights.masked_fill(mask, float('-inf'))

-        weights = softmax(weights, dim=1)
+        weights = weights.softmax(dim=1)
        # (num_frames, num_channels, num_inputs) * (num_frames, num_inputs, 1) -> (num_frames, num_channels, 1),
        ans = torch.matmul(stacked_inputs, weights.unsqueeze(2))
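
Both hunks make the same change: the free function `softmax(...)` is replaced by the `Tensor.softmax(...)` method, which is equivalent to `torch.nn.functional.softmax(input, dim=...)`. As a minimal sketch of the equivalence (the tensor shape below is illustrative, not taken from the recipe):

```python
import torch

# Illustrative shape only: (batch * num_heads, seq_len, seq_len)
attn_output_weights = torch.randn(8, 16, 16)

# Functional form, as in the replaced call...
via_functional = torch.nn.functional.softmax(attn_output_weights, dim=-1)
# ...and the tensor-method form introduced by this diff.
via_method = attn_output_weights.softmax(dim=-1)

assert torch.allclose(via_functional, via_method)
```

The method form avoids relying on a module-level `softmax` name and keeps the call style consistent with the surrounding tensor operations.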