Use atan not tanh

Daniel Povey 2022-11-10 22:32:38 +08:00
parent 60274ea731
commit 2c6f5e82b2

@@ -1002,22 +1002,24 @@ class RelPositionalEncoding(torch.nn.Module):
             device=x.device).to(torch.float32).unsqueeze(1)
-        length_factor = self.embed_dim / (2.0*math.pi) # todo: test this.
+        length_factor = self.embed_dim / (2.0 * math.pi) # todo: test this.
         freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)
-        x_tanh = (x / length_factor).tanh() # results between -pi and pi
-        cosines = (x_tanh * freqs).cos()
-        sines = (x_tanh * freqs).sin()
+        x_atan = (x / length_factor).atan() # results between -pi and pi
+        cosines = (x_atan * freqs).cos()
+        sines = (x_atan * freqs).sin()
         pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device)
         pe[:, 0::2] = cosines
         pe[:, 1::2] = sines
         pe[:, -1] = 1.0  # for bias.
-        #print("cosines = ", cosines[T-r:T+r,:5])
-        #print("sines = ", sines[T-r:T+r,:5])
+        #r = 2
+        #print("cosines = ", cosines[T-r:T+r,-5:])
+        #print("sines = ", sines[T-r:T+r,-5:])
         self.pe = pe.to(dtype=x.dtype)
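
For reference, below is a minimal standalone sketch of the encoding as computed after this change. The function name rel_pos_encoding_sketch and the flat positions tensor are invented for illustration; the surrounding class (caching into self.pe, device handling, how x is built) is omitted.

import math
import torch

def rel_pos_encoding_sketch(positions: torch.Tensor, embed_dim: int) -> torch.Tensor:
    # positions: signed relative offsets, shape (num_positions,)
    x = positions.to(torch.float32).unsqueeze(1)   # (num_positions, 1)
    length_factor = embed_dim / (2.0 * math.pi)
    freqs = 1 + torch.arange(embed_dim // 2)       # integer frequencies 1 .. embed_dim // 2
    # atan squashes the scaled position into (-pi/2, pi/2); unlike tanh it
    # approaches its asymptotes only polynomially, so large offsets stay separated.
    x_atan = (x / length_factor).atan()            # (num_positions, 1)
    cosines = (x_atan * freqs).cos()               # (num_positions, embed_dim // 2)
    sines = (x_atan * freqs).sin()
    pe = torch.zeros(x.shape[0], embed_dim)
    pe[:, 0::2] = cosines                          # even dims hold cosines
    pe[:, 1::2] = sines                            # odd dims hold sines
    pe[:, -1] = 1.0                                # last dim acts as a bias term
    return pe

# Example: relative offsets -4 .. 4 with a small embedding dimension.
pe = rel_pos_encoding_sketch(torch.arange(-4, 5), embed_dim=8)
print(pe.shape)  # torch.Size([9, 8])

The switch from tanh to atan presumably matters because tanh saturates exponentially fast, collapsing all offsets beyond a few multiples of length_factor onto nearly the same value, whereas atan approaches its limits of -pi/2 and pi/2 only like 1/x, so distant relative positions keep distinct encodings.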