mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Use atan not tanh
This commit is contained in:
parent
60274ea731
commit
2c6f5e82b2
@ -1002,22 +1002,24 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
device=x.device).to(torch.float32).unsqueeze(1)
|
||||
|
||||
|
||||
length_factor = self.embed_dim / (2.0*math.pi) # todo: test this.
|
||||
length_factor = self.embed_dim / (2.0 * math.pi) # todo: test this.
|
||||
|
||||
freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)
|
||||
|
||||
x_tanh = (x / length_factor).tanh() # results between -pi and pi
|
||||
x_atan = (x / length_factor).atan() # results between -pi and pi
|
||||
|
||||
cosines = (x_tanh * freqs).cos()
|
||||
sines = (x_tanh * freqs).sin()
|
||||
|
||||
cosines = (x_atan * freqs).cos()
|
||||
sines = (x_atan * freqs).sin()
|
||||
|
||||
pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device)
|
||||
pe[:, 0::2] = cosines
|
||||
pe[:, 1::2] = sines
|
||||
pe[:, -1] = 1.0 # for bias.
|
||||
|
||||
#print("cosines = ", cosines[T-r:T+r,:5])
|
||||
#print("sines = ", sines[T-r:T+r,:5])
|
||||
#r = 2
|
||||
#print("cosines = ", cosines[T-r:T+r,-5:])
|
||||
#print("sines = ", sines[T-r:T+r,-5:])
|
||||
|
||||
|
||||
self.pe = pe.to(dtype=x.dtype)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user