New formula for pos emb

parent fd26b890d2
commit 60274ea731
@@ -967,20 +967,16 @@ class RelPositionalEncoding(torch.nn.Module):

    Args:
      embed_dim: Embedding dimension.
      max_offset: Determines the largest offset that can be coded distinctly from smaller
         offsets.  If you change this, the embedding will change.
      dropout_rate: Dropout rate.
      max_len: Maximum input length: just a heuristic for initialization.
    """
    def __init__(
        self, embed_dim: int,
        dropout_rate: float,
        max_offset: float = 1000.0,
-       max_len: int = 10
+       max_len: int = 1000
    ) -> None:
        """Construct a PositionalEncoding object."""
        super(RelPositionalEncoding, self).__init__()
        self.max_offset = max_offset
        self.embed_dim = embed_dim
        assert embed_dim % 2 == 0
        self.dropout = torch.nn.Dropout(dropout_rate)
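As a rough sanity check of the docstring's claim about `max_offset`, the second hunk below turns it into a per-dimension geometric ratio and a set of tanh centers. A minimal sketch of those numbers with assumed example values (`embed_dim = 8` is chosen only for readability and is not from the commit):

```python
import torch

# Hypothetical worked example, not part of the commit: how max_offset
# translates into the geometric ratio and tanh centers computed below.
embed_dim = 8            # assumed small even value for readability
max_offset = 1000.0      # default from the diff above

ratio = max_offset ** (2.0 / embed_dim)   # 1000 ** 0.25 ~= 5.62
centers = (1.0 / (ratio - 1.0)) * (ratio ** torch.arange(embed_dim // 2))
centers = centers + (0.5 - centers[0])    # shift so centers[0] == 0.5
print(centers)   # ~tensor([0.5000, 1.5000, 7.1234, 38.7463])
```

After the shift, `centers[1] - centers[0]` equals `(1 / (ratio - 1)) * (ratio - 1) = 1` exactly, which is why the comment in the hunk can promise that the second value is about 1.5.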
@@ -1005,27 +1001,24 @@ class RelPositionalEncoding(torch.nn.Module):

        x = torch.arange(-(T-1), T,
                         device=x.device).to(torch.float32).unsqueeze(1)

        # e.g. ratio might be 1.1
        ratio = self.max_offset ** (2.0 / self.embed_dim)
        # e.g. width_factor == 0.1, if ratio is 1.1.  determines steepness of sigmoid.
        width_factor = ratio - 1.0

        # centers of tanh functions (positive)
        centers = (1.0 / (ratio - 1.0)) * (ratio ** torch.arange(self.embed_dim // 2,
                                                                 device=x.device))
        # adjust centers so the 1st value is 0.5, the 2nd value is about 1.5, and so on.
        centers = centers + (0.5 - centers[0])
        centers = centers.reshape(self.embed_dim // 2, 1).expand(self.embed_dim // 2, 2).contiguous()
        centers[:, 1] *= -1.0
        centers = centers.reshape(self.embed_dim)
        length_factor = self.embed_dim / (2.0 * math.pi)  # todo: test this.

        widths = (centers.abs() * width_factor) + 1.0
        freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)

        x_tanh = (x / length_factor).tanh()  # results between -1 and 1

        cosines = (x_tanh * freqs).cos()
        sines = (x_tanh * freqs).sin()

        pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device)
        pe[:, 0::2] = cosines
        pe[:, 1::2] = sines
        pe[:, -1] = 1.0  # for bias.

        #print("cosines = ", cosines[T-r:T+r,:5])
        #print("sines = ", sines[T-r:T+r,:5])

        # shape: (2*T - 1, embed_dim)
        pe = ((x - centers) / widths).tanh()
        # Let the last dimension of the embedding be a constant 1.0, to provide a bias term
        # so we don't have to give a bias to the projection.
        pe[:, -1] = 1.0

        self.pe = pe.to(dtype=x.dtype)
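The hunk above assigns `pe` twice, once from warped sinusoids and once from tanh bumps around the geometric centers; the rendered diff has lost its add/remove markers, so it is not visible here which assignment the commit introduces and which it removes. A minimal self-contained sketch of both computations, assuming example values `embed_dim = 8` and `T = 4` (neither is from the commit):

```python
import math
import torch

embed_dim = 8        # assumed; must be even
max_offset = 1000.0  # default from the first hunk
T = 4                # assumed sequence length

# relative positions -(T-1) .. T-1, shape (2*T - 1, 1)
x = torch.arange(-(T - 1), T).to(torch.float32).unsqueeze(1)

ratio = max_offset ** (2.0 / embed_dim)
width_factor = ratio - 1.0

# geometrically spaced positive centers, then paired with negated copies
centers = (1.0 / (ratio - 1.0)) * (ratio ** torch.arange(embed_dim // 2))
centers = centers + (0.5 - centers[0])
centers = centers.reshape(embed_dim // 2, 1).expand(embed_dim // 2, 2).contiguous()
centers[:, 1] *= -1.0
centers = centers.reshape(embed_dim)
widths = (centers.abs() * width_factor) + 1.0

# variant 1: warped sinusoids (positions squashed by tanh before cos/sin)
length_factor = embed_dim / (2.0 * math.pi)
freqs = 1 + torch.arange(embed_dim // 2)
x_tanh = (x / length_factor).tanh()          # values in (-1, 1)
pe_sin = torch.zeros(x.shape[0], embed_dim)
pe_sin[:, 0::2] = (x_tanh * freqs).cos()
pe_sin[:, 1::2] = (x_tanh * freqs).sin()
pe_sin[:, -1] = 1.0                          # constant bias dimension

# variant 2: tanh bumps centered on the geometric centers
pe_tanh = ((x - centers) / widths).tanh()
pe_tanh[:, -1] = 1.0                         # constant bias dimension

print(pe_sin.shape, pe_tanh.shape)  # both (2*T - 1, embed_dim) == (7, 8)
```

Both variants pin the last dimension to a constant 1.0, matching the hunk's comment that this provides a bias term so the downstream projection does not need one of its own.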