Avoid infinities in padding frames

Daniel Povey 2022-12-20 19:19:45 +08:00
parent 494139d27a
commit 3b4b33af58


@@ -610,7 +610,8 @@ class ConvNorm1d(torch.nn.Module):
         counts = counts.masked_fill_(src_key_padding_mask.unsqueeze(1), 0.0)
         sqnorms = sqnorms * counts
         sqnorms = self.conv(sqnorms)
-        counts = self.conv(counts)
+        # the clamping is to avoid division by zero for padding frames.
+        counts = self.conv(counts).clamp(min=0.01)
         # scales: (N, 1, T)
         scales = (sqnorms / counts + eps.exp()) ** -0.5
         return x * scales
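
The clamp guards against division by zero: counts is masked to zero at padding positions, so a frame whose entire receptive field (after self.conv) is padding ends up with counts == 0, and sqnorms / counts then produces non-finite values in scales. Below is a minimal, self-contained sketch of the failure and the fix; the smoothing kernel, the way sqnorms is computed, and the eps value are illustrative stand-ins, not the actual ConvNorm1d internals.

import torch

N, C, T = 2, 4, 6
x = torch.randn(N, C, T)                       # (N, C, T) activations
# True where a frame is padding, like src_key_padding_mask: (N, T)
src_key_padding_mask = torch.tensor(
    [[False, False, False, False, False, False],
     [False, False, False, True,  True,  True]])

# Simple smoothing conv over time, one channel, as a stand-in for self.conv.
conv = torch.nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False)
with torch.no_grad():
    conv.weight.fill_(1.0 / 3)

counts = torch.ones(N, 1, T)
counts = counts.masked_fill_(src_key_padding_mask.unsqueeze(1), 0.0)
sqnorms = (x ** 2).mean(dim=1, keepdim=True)   # per-frame squared norm (assumed form)
sqnorms = sqnorms * counts                     # zero out padding frames
sqnorms = conv(sqnorms)

eps = torch.tensor(-5.0)                       # stand-in for the module's learned eps

# Without the clamp: a frame whose whole receptive field is padding gets
# counts == 0, so sqnorms / counts is 0 / 0 and scales is non-finite.
counts_raw = conv(counts)
scales_bad = (sqnorms / counts_raw + eps.exp()) ** -0.5
print(torch.isfinite(scales_bad).all())        # tensor(False)

# With the clamp from this commit: counts never drops below 0.01, the division
# stays finite, and real frames (counts close to 1) are essentially unchanged.
counts_clamped = conv(counts).clamp(min=0.01)
scales_ok = (sqnorms / counts_clamped + eps.exp()) ** -0.5
print(torch.isfinite(scales_ok).all())         # tensor(True)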