Bug fix
This commit is contained in:
parent
5fda800b6d
commit
3de8a5aef2
@ -1414,7 +1414,7 @@ class PoolingModule(nn.Module):
|
|||||||
a Tensor of shape (1, N, C)
|
a Tensor of shape (1, N, C)
|
||||||
"""
|
"""
|
||||||
if key_padding_mask is not None:
|
if key_padding_mask is not None:
|
||||||
pooling_mask = key_padding_mask.logical_not().to(src.dtype) # (N, T)
|
pooling_mask = key_padding_mask.logical_not().to(x.dtype) # (N, T)
|
||||||
pooling_mask = (pooling_mask / pooling_mask.sum(dim=1, keepdim=True))
|
pooling_mask = (pooling_mask / pooling_mask.sum(dim=1, keepdim=True))
|
||||||
pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1)
|
pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1)
|
||||||
# now pooling_mask: (T, N, 1)
|
# now pooling_mask: (T, N, 1)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user