New formula for pos emb

Daniel Povey 2022-11-10 22:03:42 +08:00
parent fd26b890d2
commit 60274ea731

@@ -967,20 +967,16 @@ class RelPositionalEncoding(torch.nn.Module):
 
     Args:
         embed_dim: Embedding dimension.
-        max_offset: determines the largest offset that can be coded distinctly from smaller
-            offsets. If you change this, the embedding will change.
         dropout_rate: Dropout rate.
         max_len: Maximum input length: just a heuristic for initialization.
     """
     def __init__(
         self, embed_dim: int,
             dropout_rate: float,
-            max_offset: float = 1000.0,
-            max_len: int = 10
+            max_len: int = 1000
     ) -> None:
         """Construct a PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
-        self.max_offset = max_offset
         self.embed_dim = embed_dim
         assert embed_dim % 2 == 0
         self.dropout = torch.nn.Dropout(dropout_rate)
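
For context on the signature change above: after this hunk, max_offset is gone and max_len (default 1000) is only an initialization heuristic. A hedged usage sketch, with arbitrary illustrative values for embed_dim and dropout_rate (not taken from the repo):

    # construction after this commit; max_len defaults to 1000
    pos_emb = RelPositionalEncoding(embed_dim=256, dropout_rate=0.1)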
@@ -1005,27 +1001,24 @@ class RelPositionalEncoding(torch.nn.Module):
         x = torch.arange(-(T-1), T,
                          device=x.device).to(torch.float32).unsqueeze(1)
-        # e.g. ratio might be 1.1
-        ratio = self.max_offset ** (2.0 / self.embed_dim)
-        # e.g. width_factor == 0.1, if ratio is 1.1. determines steepness of sigmoid.
-        width_factor = ratio - 1.0
-        # centers of tanh functions (positive)
-        centers = (1.0 / (ratio - 1.0)) * (ratio ** torch.arange(self.embed_dim // 2,
-                                                                 device=x.device))
-        # adjust centers so the 1st value is 0.5, the 2nd value is about 1.5, and so on.
-        centers = centers + (0.5 - centers[0])
-        centers = centers.reshape(self.embed_dim // 2, 1).expand(self.embed_dim // 2, 2).contiguous()
-        centers[:, 1] *= -1.0
-        centers = centers.reshape(self.embed_dim)
+        length_factor = self.embed_dim / (2.0*math.pi)  # todo: test this.
-        widths = (centers.abs() * width_factor) + 1.0
+        freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)
+        x_tanh = (x / length_factor).tanh()  # results between -1 and 1
+        cosines = (x_tanh * freqs).cos()
+        sines = (x_tanh * freqs).sin()
+        pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device)
+        pe[:, 0::2] = cosines
+        pe[:, 1::2] = sines
+        pe[:, -1] = 1.0  # for bias.
+        #print("cosines = ", cosines[T-r:T+r,:5])
+        #print("sines = ", sines[T-r:T+r,:5])
-        # shape: (2*T - 1, embed_dim)
-        pe = ((x - centers) / widths).tanh()
-        # Let the last dimension of the embedding be a constant 1.0, to provide a bias term
-        # so we don't have to give a bias to the projection.
-        pe[:, -1] = 1.0
         self.pe = pe.to(dtype=x.dtype)
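
To see what the new formula produces, here is a minimal standalone sketch of the added lines, outside the class, with toy values of T and embed_dim chosen only for illustration; variable names mirror the diff, and the result has one row per relative offset:

    import math
    import torch

    T, embed_dim = 4, 8  # toy values, not from the commit
    # relative offsets -(T-1) .. T-1, shape (2*T - 1, 1)
    x = torch.arange(-(T - 1), T).to(torch.float32).unsqueeze(1)

    length_factor = embed_dim / (2.0 * math.pi)
    freqs = 1 + torch.arange(embed_dim // 2)
    x_tanh = (x / length_factor).tanh()  # squashes large offsets into (-1, 1)
    cosines = (x_tanh * freqs).cos()
    sines = (x_tanh * freqs).sin()

    pe = torch.zeros(x.shape[0], embed_dim)
    pe[:, 0::2] = cosines   # even channels: cosines
    pe[:, 1::2] = sines     # odd channels: sines
    pe[:, -1] = 1.0         # constant channel acts as a bias term

    print(pe.shape)  # torch.Size([7, 8]) == (2*T - 1, embed_dim)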