diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 7e8bea503..aa6e3d1ae 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1414,7 +1414,7 @@ class PoolingModule(nn.Module): a Tensor of shape (1, N, C) """ if key_padding_mask is not None: - pooling_mask = key_padding_mask.logical_not().to(src.dtype) # (N, T) + pooling_mask = key_padding_mask.logical_not().to(x.dtype) # (N, T) pooling_mask = (pooling_mask / pooling_mask.sum(dim=1, keepdim=True)) pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1) # now pooling_mask: (T, N, 1)