Mirror of https://github.com/k2-fsa/icefall.git
Rework positional encoding
parent e4a3b2da7d
commit 125ea04a42
@@ -964,80 +964,89 @@ class SimpleCombiner(torch.nn.Module):
 class RelPositionalEncoding(torch.nn.Module):
     """Relative positional encoding module.

-    See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
-    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py

     Args:
         embed_dim: Embedding dimension.
+        max_offset: determines the largest offset that can be coded distinctly from smaller
+           offsets. If you change this, the embedding will change.
         dropout_rate: Dropout rate.
-        max_len: Maximum input length.
+        max_len: Maximum input length: just a heuristic for initialization.

     """

     def __init__(
-        self, embed_dim: int, dropout_rate: float, max_len: int = 5000
+        self, embed_dim: int,
+        dropout_rate: float,
+        max_offset: float = 1000.0,
+        max_len: int = 10
     ) -> None:
         """Construct a PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
+        self.max_offset = max_offset
         self.embed_dim = embed_dim
+        assert embed_dim % 2 == 0
         self.dropout = torch.nn.Dropout(dropout_rate)
         self.pe = None
-        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+        self.extend_pe(torch.tensor(0.0).expand(max_len))

     def extend_pe(self, x: Tensor) -> None:
         """Reset the positional encodings."""
         if self.pe is not None:
             # self.pe contains both positive and negative parts
             # the length of self.pe is 2 * input_len - 1
-            if self.pe.size(1) >= x.size(0) * 2 - 1:
+            if self.pe.size(0) >= x.size(0) * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
                 if self.pe.dtype != x.dtype or str(self.pe.device) != str(
                     x.device
                 ):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
-        # Suppose `i` means to the position of query vecotr and `j` means the
-        # position of key vector. We use position relative positions when keys
-        # are to the left (i>j) and negative relative positions otherwise (i<j).
-        pe_positive = torch.zeros(x.size(0), self.embed_dim)
-        pe_negative = torch.zeros(x.size(0), self.embed_dim)
-        position = torch.arange(0, x.size(0), dtype=torch.float32).unsqueeze(1)
-        div_term = torch.exp(
-            torch.arange(0, self.embed_dim, 2, dtype=torch.float32)
-            * -(math.log(10000.0) / self.embed_dim)
-        )
-        pe_positive[:, 0::2] = torch.sin(position * div_term)
-        pe_positive[:, 1::2] = torch.cos(position * div_term)
-        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
-        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
-
-        # Reserve the order of positive indices and concat both positive and
-        # negative indices. This is used to support the shifting trick
-        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
-        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
-        pe_negative = pe_negative[1:].unsqueeze(0)
-        pe = torch.cat([pe_positive, pe_negative], dim=1)
-        self.pe = pe.to(device=x.device, dtype=x.dtype)
+        T = x.size(0)
+        # if T == 4, x would contain [ -3, -2, -1, 0, 1, 2, 3 ]
+        x = torch.arange(-(T-1), T,
+                         device=x.device).to(torch.float32).unsqueeze(1)
+
+        # e.g. ratio might be 1.1
+        ratio = self.max_offset ** (2.0 / self.embed_dim)
+        # e.g. width_factor == 0.1, if ratio is 1.1. determines steepness of sigmoid.
+        width_factor = ratio - 1.0
+
+        max_val = 1000.0
+
+        # centers of sigmoids (positive)
+        centers = (1.0 / (ratio - 1.0)) * (ratio ** torch.arange(self.embed_dim // 2,
+                                                                 device=x.device))
+        # adjust centers so the 1st value is 0.5, the 2nd value is about 1.5, and so on.
+        centers = centers + (0.5 - centers[0])
+        centers = centers.reshape(self.embed_dim // 2, 1).expand(self.embed_dim // 2, 2).contiguous()
+        centers[:, 1] *= -1.0
+        centers = centers.reshape(self.embed_dim)
+        widths = (centers.abs() * width_factor).clamp(min=1.0)
+
+        # shape: (2*T - 1, embed_dim)
+        pe = ((x - centers) / widths).sigmoid()
+
+        self.pe = pe.to(dtype=x.dtype)

-    def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
+    def forward(self, x: torch.Tensor) -> Tensor:
         """Add positional encoding.

         Args:
             x (torch.Tensor): Input tensor (time, batch, `*`).

         Returns:
-            torch.Tensor: Encoded tensor (batch, time, `*`).
             torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).

         """
         self.extend_pe(x)
         pos_emb = self.pe[
-            :,
-            self.pe.size(1) // 2
+            self.pe.size(0) // 2
             - x.size(0)
-            + 1 : self.pe.size(1) // 2  # noqa E203
+            + 1 : self.pe.size(0) // 2  # noqa E203
             + x.size(0),
+            :
         ]
+        pos_emb = pos_emb.unsqueeze(0)  # now: (1, 2*time-1, embed_dim)
         return self.dropout(pos_emb)
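For reference, the added extend_pe builds the encoding from log-spaced sigmoid channels instead of the Transformer-XL sinusoids it replaces. Below is a minimal standalone sketch of that construction; embed_dim=8 and T=4 are illustrative values (only the max_offset=1000.0 and max_len=10 defaults come from the diff above), so treat it as a reading aid rather than the repository code.

import torch

# Illustrative sizes; max_offset matches the new default in the diff above.
embed_dim, max_offset, T = 8, 1000.0, 4

# Relative offsets -(T-1) .. (T-1); for T == 4 this is [-3, -2, -1, 0, 1, 2, 3].
offsets = torch.arange(-(T - 1), T).to(torch.float32).unsqueeze(1)  # (2*T - 1, 1)

ratio = max_offset ** (2.0 / embed_dim)  # geometric spacing between sigmoid centers
width_factor = ratio - 1.0               # controls how steep each sigmoid is

# Log-spaced positive centers, shifted so the first center sits at 0.5.
centers = (1.0 / (ratio - 1.0)) * (ratio ** torch.arange(embed_dim // 2))
centers = centers + (0.5 - centers[0])
# Pair every positive center with a mirrored negative counterpart.
centers = centers.reshape(embed_dim // 2, 1).expand(embed_dim // 2, 2).contiguous()
centers[:, 1] *= -1.0
centers = centers.reshape(embed_dim)
widths = (centers.abs() * width_factor).clamp(min=1.0)

# Each channel is a sigmoid of the relative offset, centered/scaled per channel.
pe = ((offsets - centers) / widths).sigmoid()
print(pe.shape)  # torch.Size([7, 8]), i.e. (2*T - 1, embed_dim)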
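A similarly hedged sketch of the indexing done by the reworked forward(): it takes the 2*seq_len - 1 rows of self.pe centered on offset zero and adds a leading batch dimension. The random tensor is only a stand-in for self.pe, and L and seq_len are illustrative (L = 10 mirrors the new max_len default).

import torch

L, embed_dim = 10, 8
pe = torch.randn(2 * L - 1, embed_dim)  # stand-in for the self.pe built by extend_pe

seq_len = 4  # hypothetical input length; seq_len <= L, so no re-extension is needed
mid = pe.size(0) // 2
# Same slice as the new forward(): rows for offsets -(seq_len-1) .. (seq_len-1).
pos_emb = pe[mid - seq_len + 1 : mid + seq_len, :]
pos_emb = pos_emb.unsqueeze(0)
print(pos_emb.shape)  # torch.Size([1, 7, 8]), i.e. (1, 2*seq_len - 1, embed_dim)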