Use compression of large x in the formula for pos_emb

This commit is contained in:
Daniel Povey 2022-11-13 13:23:42 +08:00
parent 6c16d08b4f
commit 463fed3d6a

View File

@ -1008,14 +1008,28 @@ class RelPositionalEncoding(torch.nn.Module):
x = torch.arange(-(T-1), T, x = torch.arange(-(T-1), T,
device=x.device).to(torch.float32).unsqueeze(1) device=x.device).to(torch.float32).unsqueeze(1)
length_factor = self.embed_dim / (2.0 * math.pi) # todo: test this.
freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)
x_atan = (x / length_factor).atan() # results between -pi and pi # `compression_length` is arbitrary/heuristic: if it is larger we have more resolution
# for small time offsets but less resolution for large time offsets.
compression_length = (self.embed_dim ** 0.5)
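# x_compressed, like x, goes from -infinity to infinity as the time offset x goes from
# -infinity to infinity; but it does so more slowly than x for large absolute values of x.
# The formula is chosen so that d(x_compressed)/dx is 1 around x == 0, which
# is important.
# NOTE(review): as written, x_compressed is 0 at x == 0 but about
# +-compression_length * log(compression_length) at |x| == 1, i.e. there is a jump at the
# origin; subtracting log(compression_length) inside the parentheses would remove that
# offset while keeping the slope of 1 -- confirm the intent.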
x_compressed = compression_length * x.sign() * (x.abs() + compression_length).log()
# length_factor is chosen so that the FFT can exactly separate points
# close to the origin (T == 0). So this part of the formulation is not really
# heuristic.
length_factor = self.embed_dim / (2.0 * math.pi) # todo: test this.
# note for machine implementations: if atan is not available, we can use:
# x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2)
# check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x))
x_atan = (x_compressed / length_factor).atan() # results between -pi/2 and pi/2
cosines = (x_atan * freqs).cos() cosines = (x_atan * freqs).cos()
sines = (x_atan * freqs).sin() sines = (x_atan * freqs).sin()