minor fixes

This commit is contained in:
JinZr 2023-08-14 21:05:56 +08:00
parent 372f63ae7f
commit 9d2ac7b1ec

View File

@ -374,9 +374,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
pos_scores = pos_scores.reshape(-1, n) pos_scores = pos_scores.reshape(-1, n)
pos_scores = torch.gather(pos_scores, dim=1, index=indexes) pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, lm_seq_len) pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, lm_seq_len)
else: elif not for_reference:
pos_scores = pos_scores.as_strided( pos_scores = pos_scores.as_strided(
(num_heads, b_p_dim, lm_seq_len, lm_seq_len), (num_heads, b_p_dim, lm_seq_len, am_seq_len),
( (
pos_scores.stride(0), pos_scores.stride(0),
pos_scores.stride(1), pos_scores.stride(1),