mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Rework positional encoding
This commit is contained in:
parent
e4a3b2da7d
commit
125ea04a42
@ -964,80 +964,89 @@ class SimpleCombiner(torch.nn.Module):
|
||||
class RelPositionalEncoding(torch.nn.Module):
|
||||
"""Relative positional encoding module.
|
||||
|
||||
See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
||||
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
|
||||
|
||||
Args:
|
||||
embed_dim: Embedding dimension.
|
||||
max_offset: determines the largest offset that can be coded distinctly from smaller
|
||||
offsets. If you change this, the embedding will change.
|
||||
dropout_rate: Dropout rate.
|
||||
max_len: Maximum input length.
|
||||
|
||||
max_len: Maximum input length: just a heuristic for initialization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, embed_dim: int, dropout_rate: float, max_len: int = 5000
|
||||
self, embed_dim: int,
|
||||
dropout_rate: float,
|
||||
max_offset: float = 1000.0,
|
||||
max_len: int = 10
|
||||
) -> None:
|
||||
"""Construct a PositionalEncoding object."""
|
||||
super(RelPositionalEncoding, self).__init__()
|
||||
self.max_offset = max_offset
|
||||
self.embed_dim = embed_dim
|
||||
assert embed_dim % 2 == 0
|
||||
self.dropout = torch.nn.Dropout(dropout_rate)
|
||||
self.pe = None
|
||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||
self.extend_pe(torch.tensor(0.0).expand(max_len))
|
||||
|
||||
def extend_pe(self, x: Tensor) -> None:
|
||||
"""Reset the positional encodings."""
|
||||
if self.pe is not None:
|
||||
# self.pe contains both positive and negative parts
|
||||
# the length of self.pe is 2 * input_len - 1
|
||||
if self.pe.size(1) >= x.size(0) * 2 - 1:
|
||||
if self.pe.size(0) >= x.size(0) * 2 - 1:
|
||||
# Note: TorchScript doesn't implement operator== for torch.Device
|
||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
||||
x.device
|
||||
):
|
||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||
return
|
||||
# Suppose `i` means to the position of query vecotr and `j` means the
|
||||
# position of key vector. We use position relative positions when keys
|
||||
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
||||
pe_positive = torch.zeros(x.size(0), self.embed_dim)
|
||||
pe_negative = torch.zeros(x.size(0), self.embed_dim)
|
||||
position = torch.arange(0, x.size(0), dtype=torch.float32).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, self.embed_dim, 2, dtype=torch.float32)
|
||||
* -(math.log(10000.0) / self.embed_dim)
|
||||
)
|
||||
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
||||
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
||||
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
||||
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
||||
|
||||
# Reserve the order of positive indices and concat both positive and
|
||||
# negative indices. This is used to support the shifting trick
|
||||
# as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
||||
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
||||
pe_negative = pe_negative[1:].unsqueeze(0)
|
||||
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||
T = x.size(0)
|
||||
# if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ]
|
||||
x = torch.arange(-(T-1), T,
|
||||
device=x.device).to(torch.float32).unsqueeze(1)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
|
||||
# e.g. ratio might be 1.1
|
||||
ratio = self.max_offset ** (2.0 / self.embed_dim)
|
||||
# e.g. width_factor == 0.1, if ratio is 1.1. determines steepness of sigmoid.
|
||||
width_factor = ratio - 1.0
|
||||
|
||||
max_val = 1000.0
|
||||
|
||||
# centers of sigmoids (positive)
|
||||
centers = (1.0 / (ratio - 1.0)) * (ratio ** torch.arange(self.embed_dim // 2,
|
||||
device=x.device))
|
||||
# adjust centers so the 1st value is 0.5, the 2nd value is about 1.5, and so on.
|
||||
centers = centers + (0.5 - centers[0])
|
||||
centers = centers.reshape(self.embed_dim // 2, 1).expand(self.embed_dim // 2, 2).contiguous()
|
||||
centers[:, 1] *= -1.0
|
||||
centers = centers.reshape(self.embed_dim)
|
||||
widths = (centers.abs() * width_factor).clamp(min=1.0)
|
||||
|
||||
# shape: (2*T - 1, embed_dim)
|
||||
pe = ((x - centers) / widths).sigmoid()
|
||||
|
||||
self.pe = pe.to(dtype=x.dtype)
|
||||
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tensor:
|
||||
"""Add positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (time, batch, `*`).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor (batch, time, `*`).
|
||||
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
||||
|
||||
"""
|
||||
self.extend_pe(x)
|
||||
pos_emb = self.pe[
|
||||
:,
|
||||
self.pe.size(1) // 2
|
||||
self.pe.size(0) // 2
|
||||
- x.size(0)
|
||||
+ 1 : self.pe.size(1) // 2 # noqa E203
|
||||
+ 1 : self.pe.size(0) // 2 # noqa E203
|
||||
+ x.size(0),
|
||||
:
|
||||
]
|
||||
pos_emb = pos_emb.unsqueeze(0) # now: (1, 2*time-1, embed_dim)
|
||||
return self.dropout(pos_emb)
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user