Rework positional encoding

This commit is contained in:
Daniel Povey 2022-11-09 20:48:27 +08:00
parent e4a3b2da7d
commit 125ea04a42

View File

@ -964,80 +964,89 @@ class SimpleCombiner(torch.nn.Module):
class RelPositionalEncoding(torch.nn.Module): class RelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding module. """Relative positional encoding module.
See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
Args: Args:
embed_dim: Embedding dimension. embed_dim: Embedding dimension.
max_offset: determines the largest offset that can be coded distinctly from smaller
offsets. If you change this, the embedding will change.
dropout_rate: Dropout rate. dropout_rate: Dropout rate.
max_len: Maximum input length. max_len: Maximum input length: just a heuristic for initialization.
""" """
def __init__( def __init__(
self, embed_dim: int, dropout_rate: float, max_len: int = 5000 self, embed_dim: int,
dropout_rate: float,
max_offset: float = 1000.0,
max_len: int = 10
) -> None: ) -> None:
"""Construct a PositionalEncoding object.""" """Construct a PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__() super(RelPositionalEncoding, self).__init__()
self.max_offset = max_offset
self.embed_dim = embed_dim self.embed_dim = embed_dim
assert embed_dim % 2 == 0
self.dropout = torch.nn.Dropout(dropout_rate) self.dropout = torch.nn.Dropout(dropout_rate)
self.pe = None self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len)) self.extend_pe(torch.tensor(0.0).expand(max_len))
def extend_pe(self, x: Tensor) -> None: def extend_pe(self, x: Tensor) -> None:
"""Reset the positional encodings.""" """Reset the positional encodings."""
if self.pe is not None: if self.pe is not None:
# self.pe contains both positive and negative parts # self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1 # the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(0) * 2 - 1: if self.pe.size(0) >= x.size(0) * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device # Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str( if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device x.device
): ):
self.pe = self.pe.to(dtype=x.dtype, device=x.device) self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return return
# Suppose `i` means to the position of query vecotr and `j` means the
# position of key vector. We use position relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(0), self.embed_dim)
pe_negative = torch.zeros(x.size(0), self.embed_dim)
position = torch.arange(0, x.size(0), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.embed_dim, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.embed_dim)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
# Reserve the order of positive indices and concat both positive and T = x.size(0)
# negative indices. This is used to support the shifting trick # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ]
# as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" x = torch.arange(-(T-1), T,
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0) device=x.device).to(torch.float32).unsqueeze(1)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]: # e.g. ratio might be 1.1
ratio = self.max_offset ** (2.0 / self.embed_dim)
# e.g. width_factor == 0.1, if ratio is 1.1. determines steepness of sigmoid.
width_factor = ratio - 1.0
max_val = 1000.0
# centers of sigmoids (positive)
centers = (1.0 / (ratio - 1.0)) * (ratio ** torch.arange(self.embed_dim // 2,
device=x.device))
# adjust centers so the 1st value is 0.5, the 2nd value is about 1.5, and so on.
centers = centers + (0.5 - centers[0])
centers = centers.reshape(self.embed_dim // 2, 1).expand(self.embed_dim // 2, 2).contiguous()
centers[:, 1] *= -1.0
centers = centers.reshape(self.embed_dim)
widths = (centers.abs() * width_factor).clamp(min=1.0)
# shape: (2*T - 1, embed_dim)
pe = ((x - centers) / widths).sigmoid()
self.pe = pe.to(dtype=x.dtype)
def forward(self, x: torch.Tensor) -> Tensor:
"""Add positional encoding. """Add positional encoding.
Args: Args:
x (torch.Tensor): Input tensor (time, batch, `*`). x (torch.Tensor): Input tensor (time, batch, `*`).
Returns: Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
""" """
self.extend_pe(x) self.extend_pe(x)
pos_emb = self.pe[ pos_emb = self.pe[
:, self.pe.size(0) // 2
self.pe.size(1) // 2
- x.size(0) - x.size(0)
+ 1 : self.pe.size(1) // 2 # noqa E203 + 1 : self.pe.size(0) // 2 # noqa E203
+ x.size(0), + x.size(0),
:
] ]
pos_emb = pos_emb.unsqueeze(0) # now: (1, 2*time-1, embed_dim)
return self.dropout(pos_emb) return self.dropout(pos_emb)