mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
fixes
This commit is contained in:
parent
3b4fa4863f
commit
755430c29e
@ -407,14 +407,18 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
# (batch, max_len)
|
||||
|
||||
key_padding_mask = (
|
||||
(
|
||||
key_padding_mask.unsqueeze(0)
|
||||
.repeat(1, self.prune_range, 1)
|
||||
.unsqueeze(2)
|
||||
key_padding_mask.unsqueeze(1)
|
||||
.expand(
|
||||
key_padding_mask.shape[0], # b
|
||||
self.prune_range,
|
||||
key_padding_mask.shape[1], # l
|
||||
)
|
||||
if key_padding_mask.shape[0] != attn_scores.shape[1]
|
||||
else key_padding_mask.unsqueeze(0).unsqueeze(2)
|
||||
.reshape(b_p_dim, am_seq_len)
|
||||
.unsqueeze(1)
|
||||
.unsqueeze(0)
|
||||
)
|
||||
# print(key_padding_mask.shape)
|
||||
# (1, b * p, 1, T)
|
||||
|
||||
attn_scores = attn_scores.masked_fill(
|
||||
key_padding_mask,
|
||||
|
Loading…
x
Reference in New Issue
Block a user