mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
fixes
This commit is contained in:
parent
3b4fa4863f
commit
755430c29e
@ -407,14 +407,18 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
|||||||
# (batch, max_len)
|
# (batch, max_len)
|
||||||
|
|
||||||
key_padding_mask = (
|
key_padding_mask = (
|
||||||
(
|
key_padding_mask.unsqueeze(1)
|
||||||
key_padding_mask.unsqueeze(0)
|
.expand(
|
||||||
.repeat(1, self.prune_range, 1)
|
key_padding_mask.shape[0], # b
|
||||||
.unsqueeze(2)
|
self.prune_range,
|
||||||
|
key_padding_mask.shape[1], # l
|
||||||
)
|
)
|
||||||
if key_padding_mask.shape[0] != attn_scores.shape[1]
|
.reshape(b_p_dim, am_seq_len)
|
||||||
else key_padding_mask.unsqueeze(0).unsqueeze(2)
|
.unsqueeze(1)
|
||||||
|
.unsqueeze(0)
|
||||||
)
|
)
|
||||||
|
# print(key_padding_mask.shape)
|
||||||
|
# (1, b * p, 1, T)
|
||||||
|
|
||||||
attn_scores = attn_scores.masked_fill(
|
attn_scores = attn_scores.masked_fill(
|
||||||
key_padding_mask,
|
key_padding_mask,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user