mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
minor fixes
This commit is contained in:
parent
372f63ae7f
commit
9d2ac7b1ec
@ -374,9 +374,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
pos_scores = pos_scores.reshape(-1, n)
|
||||
pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
|
||||
pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, lm_seq_len)
|
||||
else:
|
||||
elif not for_reference:
|
||||
pos_scores = pos_scores.as_strided(
|
||||
(num_heads, b_p_dim, lm_seq_len, lm_seq_len),
|
||||
(num_heads, b_p_dim, lm_seq_len, am_seq_len),
|
||||
(
|
||||
pos_scores.stride(0),
|
||||
pos_scores.stride(1),
|
||||
|
Loading…
x
Reference in New Issue
Block a user