Tweak formula for widths

This commit is contained in:
Daniel Povey 2022-11-10 13:04:56 +08:00
parent 6091146e91
commit fd26b890d2

View File

@@ -1019,7 +1019,7 @@ class RelPositionalEncoding(torch.nn.Module):
centers[:, 1] *= -1.0
centers = centers.reshape(self.embed_dim)
- widths = (centers.abs() * width_factor).clamp(min=1.0)
+ widths = (centers.abs() * width_factor) + 1.0
# shape: (2*T - 1, embed_dim)
pe = ((x - centers) / widths).tanh()