Bug fix in formula for pos embedding

Daniel Povey 2022-11-17 16:02:57 +08:00
parent 526b5e59a6
commit e73ced1607


@@ -932,8 +932,7 @@ class CompactRelPositionalEncoding(torch.nn.Module):
 # but it does so more slowly than T for large absolute values of T.
 # The formula is chosen so that d(x_compressed)/dx is 1 around x == 0, which
 # is important.
-x_compressed = compression_length * x.sign() * (x.abs() + compression_length).log()
+x_compressed = compression_length * x.sign() * ((x.abs() + compression_length).log() - math.log(compression_length))
 # length_factor is chosen so that the FFT can exactly separate points
 # close to the origin (T == 0). So this part of the formulation is not really
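
The point of the fix: without subtracting math.log(compression_length), the old expression jumps by +/- compression_length * log(compression_length) as x crosses zero (sign() flips while the log term stays nonzero), so the compressed positions are discontinuous at the origin. The fixed version equals 0 at x == 0 and has slope 1 there, since d/dx [c * log(x + c)] = c / (x + c) = 1 at x == 0. A minimal standalone sketch of this, not part of the commit; the value compression_length = 8.0 is illustrative only:

import math
import torch

compression_length = 8.0  # arbitrary illustrative value, not from the module

def compress_old(x):
    # Pre-fix formula: jumps by +/- compression_length * log(compression_length)
    # as x crosses zero, because sign() flips while the log term stays nonzero.
    return compression_length * x.sign() * (x.abs() + compression_length).log()

def compress_new(x):
    # Fixed formula: subtracting log(compression_length) makes the value 0 at
    # x == 0 and the slope d(x_compressed)/dx equal to 1 there.
    return compression_length * x.sign() * (
        (x.abs() + compression_length).log() - math.log(compression_length)
    )

x = torch.tensor([-1e-3, 0.0, 1e-3, 100.0])
print(compress_old(x))  # ~[-16.64, 0.0, 16.64, 37.5]: discontinuous at 0
print(compress_new(x))  # ~[-0.001, 0.0, 0.001, 20.8]: continuous, slope 1 at 0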