Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Change the formula for the embedding to be a bit more symmetric.
This commit is contained in:
parent 082b93d911
commit 6091146e91
@@ -1010,7 +1010,7 @@ class RelPositionalEncoding(torch.nn.Module):
         # e.g. width_factor == 0.1, if ratio is 1.1. determines steepness of sigmoid.
         width_factor = ratio - 1.0

-        # centers of sigmoids (positive)
+        # centers of tanh functions (positive)
         centers = (1.0 / (ratio - 1.0)) * (ratio ** torch.arange(self.embed_dim // 2,
                                                                  device=x.device))
         # adjust centers so the 1st value is 0.5, the 2nd value is about 1.5, and so on.
@@ -1018,10 +1018,14 @@ class RelPositionalEncoding(torch.nn.Module):
+        centers = centers.reshape(self.embed_dim // 2, 1).expand(self.embed_dim // 2, 2).contiguous()
+        centers[:, 1] *= -1.0
+        centers = centers.reshape(self.embed_dim)
+
         widths = (centers.abs() * width_factor).clamp(min=1.0)

         # shape: (2*T - 1, embed_dim)
-        pe = ((x - centers) / widths).sigmoid()
+        pe = ((x - centers) / widths).tanh()
         # Let the last dimension of the embedding be a constant 1.0, to provide a bias term
         # so we don't have to give a bias to the projection.
         pe[:, -1] = 1.0

         self.pe = pe.to(dtype=x.dtype)
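
For readers who want to see the whole computation in one place, here is a small, self-contained sketch of what the embedding construction looks like after this change. It is not the actual icefall code: the concrete values of embed_dim, ratio and T, and the construction of the relative-position tensor x, are assumptions made for illustration, and the adjustment of the centers mentioned in the comment ("1st value is 0.5, the 2nd value is about 1.5, ...") lives outside the shown hunks and is omitted here.

    import torch

    # Hypothetical values for illustration; the real module takes these from its config.
    embed_dim = 256
    ratio = 1.1          # must be > 1.0
    T = 100              # sequence length

    # Assumed construction of the relative positions -(T-1) .. (T-1) as a column,
    # so it broadcasts against `centers` below; shape (2*T - 1, 1).
    x = torch.arange(-(T - 1), T, dtype=torch.float32).unsqueeze(-1)

    # e.g. width_factor == 0.1, if ratio is 1.1; determines steepness of the tanh.
    width_factor = ratio - 1.0

    # Centers of the tanh functions (positive), geometrically spaced.
    # (The real code also adjusts these so the 1st value is 0.5, the 2nd about 1.5,
    # and so on; that step is outside the shown hunks and omitted here.)
    centers = (1.0 / (ratio - 1.0)) * (ratio ** torch.arange(embed_dim // 2, dtype=torch.float32))

    # Mirror each positive center with a negative copy: (c0, -c0, c1, -c1, ...).
    centers = centers.reshape(embed_dim // 2, 1).expand(embed_dim // 2, 2).contiguous()
    centers[:, 1] *= -1.0
    centers = centers.reshape(embed_dim)

    # Widths grow with |center|, so far-away positions get smoother responses.
    widths = (centers.abs() * width_factor).clamp(min=1.0)

    # shape: (2*T - 1, embed_dim)
    pe = ((x - centers) / widths).tanh()

    # Last dimension is a constant 1.0, acting as a bias term so the following
    # projection does not need its own bias.
    pe[:, -1] = 1.0

    print(pe.shape)  # torch.Size([199, 256])

The switch from sigmoid to tanh works together with the mirrored centers: tanh is an odd function, so the embedding responds symmetrically to positive and negative relative offsets, which appears to be the "more symmetric" behavior the commit message refers to.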