This commit is contained in:
Daniel Povey 2022-10-31 15:50:46 +08:00
parent 5fda800b6d
commit 3de8a5aef2

View File

@ -1414,7 +1414,7 @@ class PoolingModule(nn.Module):
a Tensor of shape (1, N, C)
"""
if key_padding_mask is not None:
pooling_mask = key_padding_mask.logical_not().to(src.dtype) # (N, T)
pooling_mask = key_padding_mask.logical_not().to(x.dtype) # (N, T)
pooling_mask = (pooling_mask / pooling_mask.sum(dim=1, keepdim=True))
pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1)
# now pooling_mask: (T, N, 1)