mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Bug fix
This commit is contained in:
parent
5fda800b6d
commit
3de8a5aef2
@ -1414,7 +1414,7 @@ class PoolingModule(nn.Module):
|
||||
a Tensor of shape (1, N, C)
|
||||
"""
|
||||
if key_padding_mask is not None:
|
||||
pooling_mask = key_padding_mask.logical_not().to(src.dtype) # (N, T)
|
||||
pooling_mask = key_padding_mask.logical_not().to(x.dtype) # (N, T)
|
||||
pooling_mask = (pooling_mask / pooling_mask.sum(dim=1, keepdim=True))
|
||||
pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1)
|
||||
# now pooling_mask: (T, N, 1)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user