support positional encoding for conv-emformer
This commit is contained in:
parent 50fe100f50
commit 8e6a51edaa
@@ -80,13 +80,22 @@ class EmformerAttention(nn.Module):
         self.nhead = nhead
         self.tanh_on_mem = tanh_on_mem
         self.negative_inf = negative_inf
+        self.head_dim = embed_dim // nhead
 
-        self.scaling = (self.embed_dim // self.nhead) ** -0.5
+        self.scaling = self.head_dim ** -0.5
 
         self.emb_to_key_value = nn.Linear(embed_dim, 2 * embed_dim, bias=True)
         self.emb_to_query = nn.Linear(embed_dim, embed_dim, bias=True)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
 
+        # linear transformation for positional encoding.
+        self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
+
+        # these two learnable biases are used in matrix c and matrix d
+        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", Section 3.3  # noqa
+        self.pos_bias_u = nn.Parameter(torch.Tensor(nhead, self.head_dim))
+        self.pos_bias_v = nn.Parameter(torch.Tensor(nhead, self.head_dim))
+
         self._reset_parameters()
 
     def _reset_parameters(self) -> None:
@@ -99,6 +108,11 @@ class EmformerAttention(nn.Module):
         nn.init.xavier_uniform_(self.out_proj.weight)
         nn.init.constant_(self.out_proj.bias, 0.0)
 
+        nn.init.xavier_uniform_(self.linear_pos.weight)
+
+        nn.init.xavier_uniform_(self.pos_bias_u)
+        nn.init.xavier_uniform_(self.pos_bias_v)
+
     def _gen_attention_probs(
         self,
         attention_weights: torch.Tensor,
@@ -152,6 +166,32 @@ class EmformerAttention(nn.Module):
 
         return attention_probs
 
+    def _rel_shift(self, x: torch.Tensor) -> torch.Tensor:
+        """Compute relative positional encoding.
+
+        Args:
+          x: Input tensor, of shape (B, nhead, U, PE).
+            U is the length of the query vector.
+            For non-infer mode, PE = 2 * U - 1;
+            for infer mode, PE = L + 2 * U - 1.
+
+        Returns:
+          A tensor of shape (B, nhead, U, out_len).
+          For non-infer mode, out_len = U;
+          for infer mode, out_len = L + U.
+        """
+        B, nhead, U, PE = x.size()
+        B_stride = x.stride(0)
+        nhead_stride = x.stride(1)
+        U_stride = x.stride(2)
+        PE_stride = x.stride(3)
+        out_len = PE - (U - 1)
+        return x.as_strided(
+            size=(B, nhead, U, out_len),
+            stride=(B_stride, nhead_stride, U_stride - PE_stride, PE_stride),
+            storage_offset=PE_stride * (U - 1),
+        )
+
     def _forward_impl(
         self,
         utterance: torch.Tensor,
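For orientation, the as_strided call above implements the standard rel-shift trick: each query row picks out the slice of relative positions it needs. A minimal standalone sketch, not part of the diff, with toy shapes (all numbers below are illustrative assumptions):

import torch

# Toy check of the rel-shift used in _rel_shift: non-infer mode, so PE = 2 * U - 1.
B, nhead, U = 1, 1, 3
PE = 2 * U - 1
x = torch.arange(B * nhead * U * PE, dtype=torch.float32).view(B, nhead, U, PE)

out_len = PE - (U - 1)
shifted = x.as_strided(
    size=(B, nhead, U, out_len),
    stride=(x.stride(0), x.stride(1), x.stride(2) - x.stride(3), x.stride(3)),
    storage_offset=x.stride(3) * (U - 1),
)
# Row i of `shifted` equals x[..., i, U - 1 - i : 2 * U - 1 - i], i.e. the
# out_len relative positions needed for query position i.
print(shifted)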
@@ -160,6 +200,7 @@ class EmformerAttention(nn.Module):
         summary: torch.Tensor,
         memory: torch.Tensor,
         attention_mask: torch.Tensor,
+        pos_emb: torch.Tensor,
         left_context_key: Optional[torch.Tensor] = None,
         left_context_val: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -194,6 +235,10 @@ class EmformerAttention(nn.Module):
             Memory elements, with shape (M, B, D).
           attention_mask (torch.Tensor):
             Attention mask for underlying attention, with shape (Q, KV).
+          pos_emb (torch.Tensor):
+            Position encoding embedding, with shape (PE, D).
+            For training mode, PE = 2 * U - 1;
+            for infer mode, PE = L + 2 * U - 1.
           left_context_key (torch.Tensor, optional):
             Cached attention key of left context from preceding computation,
             with shape (L, B, D).
@@ -208,7 +253,9 @@ class EmformerAttention(nn.Module):
           - attention key, with shape (KV, B, D).
           - attention value, with shape (KV, B, D).
         """
-        B = utterance.size(1)
+        U, B, _ = utterance.size()
+        R = right_context.size(0)
+        M = memory.size(0)
 
         # Compute query with [right context, utterance, summary].
         query = self.emb_to_query(
@@ -222,41 +269,71 @@ class EmformerAttention(nn.Module):
         if left_context_key is not None and left_context_val is not None:
             # This is for inference mode. Now compute key and value with
             # [mems, right context, left context, utterance]
-            M = memory.size(0)
-            R = right_context.size(0)
-            right_context_end_idx = M + R
             key = torch.cat(
-                [
-                    key[:right_context_end_idx],
-                    left_context_key,
-                    key[right_context_end_idx:],
-                ]
+                [key[: M + R], left_context_key, key[M + R :]]  # noqa
             )
             value = torch.cat(
-                [
-                    value[:right_context_end_idx],
-                    left_context_val,
-                    value[right_context_end_idx:],
-                ]
+                [value[: M + R], left_context_val, value[M + R :]]  # noqa
             )
+        Q = query.size(0)
+        KV = key.size(0)
 
-        # Compute attention weights from query, key, and value.
-        reshaped_query, reshaped_key, reshaped_value = [
+        reshaped_key, reshaped_value = [
             tensor.contiguous()
-            .view(-1, B * self.nhead, self.embed_dim // self.nhead)
+            .view(KV, B * self.nhead, self.head_dim)
             .transpose(0, 1)
-            for tensor in [query, key, value]
-        ]
-        attention_weights = torch.bmm(
-            reshaped_query * self.scaling, reshaped_key.transpose(1, 2)
+            for tensor in [key, value]
+        ]  # (B * nhead, KV, head_dim)
+        reshaped_query = query.contiguous().view(
+            Q, B, self.nhead, self.head_dim
         )
+
+        # compute attention matrix ac
+        query_with_bais_u = (
+            (reshaped_query + self.pos_bias_u)
+            .view(Q, B * self.nhead, self.head_dim)
+            .transpose(0, 1)
+        )
+        matrix_ac = torch.bmm(
+            query_with_bais_u, reshaped_key.transpose(1, 2)
+        )  # (B * nhead, Q, KV)
+
+        # compute attention matrix bd
+        utterance_with_bais_v = (
+            reshaped_query[R : R + U] + self.pos_bias_v
+        ).permute(1, 2, 0, 3)
+        # (B, nhead, U, head_dim)
+        PE = pos_emb.size(0)
+        if left_context_key is not None and left_context_val is not None:
+            L = left_context_key.size(0)
+            assert PE == L + 2 * U - 1
+        else:
+            assert PE == 2 * U - 1
+        pos_emb = (
+            self.linear_pos(pos_emb)
+            .view(PE, self.nhead, self.head_dim)
+            .transpose(0, 1)
+            .unsqueeze(0)
+        )  # (1, nhead, PE, head_dim)
+        matrix_bd_utterance = torch.matmul(
+            utterance_with_bais_v, pos_emb.transpose(-2, -1)
+        )  # (B, nhead, U, PE)
+        # rel-shift
+        matrix_bd_utterance = self._rel_shift(
+            matrix_bd_utterance
+        )  # (B, nhead, U, U or L + U)
+        matrix_bd_utterance = matrix_bd_utterance.contiguous().view(
+            B * self.nhead, U, -1
+        )
+        matrix_bd = torch.zeros_like(matrix_ac)
+        matrix_bd[:, R : R + U, M + R :] = matrix_bd_utterance
+
+        attention_weights = (matrix_ac + matrix_bd) * self.scaling
 
         # Compute padding mask
         if B == 1:
             padding_mask = None
         else:
-            KV = key.size(0)
-            U = utterance.size(0)
             padding_mask = make_pad_mask(KV - U + lengths)
 
         # Compute attention probabilities.
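matrix_ac and matrix_bd above follow the Transformer-XL decomposition of the attention score into a content term and a position term; only the utterance queries get the position term, which _rel_shift then maps onto key positions. A shape-only sketch, not part of the diff, with toy dimensions (every concrete number below is an assumption for illustration):

import torch

B, nhead, head_dim = 2, 4, 8
M, R, U, S = 4, 2, 6, 1          # memory, right context, utterance, summary lengths
Q, KV = R + U + S, M + R + U     # query and key/value lengths
PE = 2 * U - 1                   # training mode: no cached left context

reshaped_query = torch.randn(Q, B, nhead, head_dim)
reshaped_key = torch.randn(B * nhead, KV, head_dim)
pos_bias_u = torch.randn(nhead, head_dim)
pos_bias_v = torch.randn(nhead, head_dim)
pos_emb = torch.randn(1, nhead, PE, head_dim)   # after linear_pos + reshape

# Content term: (query + u) . key -> (B * nhead, Q, KV)
matrix_ac = torch.bmm(
    (reshaped_query + pos_bias_u).view(Q, B * nhead, head_dim).transpose(0, 1),
    reshaped_key.transpose(1, 2),
)

# Position term, utterance queries only: (query + v) . pos_emb -> (B, nhead, U, PE)
utterance_q = (reshaped_query[R : R + U] + pos_bias_v).permute(1, 2, 0, 3)
matrix_bd_utt = torch.matmul(utterance_q, pos_emb.transpose(-2, -1))

print(matrix_ac.shape, matrix_bd_utt.shape)
# torch.Size([8, 9, 12]) torch.Size([2, 4, 6, 11])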
@@ -266,12 +343,7 @@ class EmformerAttention(nn.Module):
 
         # Compute attention.
         attention = torch.bmm(attention_probs, reshaped_value)
-        Q = query.size(0)
-        assert attention.shape == (
-            B * self.nhead,
-            Q,
-            self.embed_dim // self.nhead,
-        )
+        assert attention.shape == (B * self.nhead, Q, self.head_dim)
         attention = (
             attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim)
         )
@@ -279,10 +351,8 @@ class EmformerAttention(nn.Module):
         # Apply output projection.
         outputs = self.out_proj(attention)
 
-        S = summary.size(0)
-        summary_start_idx = Q - S
-        output_right_context_utterance = outputs[:summary_start_idx]
-        output_memory = outputs[summary_start_idx:]
+        output_right_context_utterance = outputs[: R + U]
+        output_memory = outputs[R + U :]
         if self.tanh_on_mem:
             output_memory = torch.tanh(output_memory)
         else:
@@ -298,6 +368,7 @@ class EmformerAttention(nn.Module):
         summary: torch.Tensor,
         memory: torch.Tensor,
         attention_mask: torch.Tensor,
+        pos_emb: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # TODO: Modify docs.
         """Forward pass for training.
@@ -324,6 +395,9 @@ class EmformerAttention(nn.Module):
           attention_mask (torch.Tensor):
             Attention mask for underlying chunk-wise attention,
             with shape (Q, KV), where Q = R + U + S, KV = M + R + U.
+          pos_emb (torch.Tensor):
+            Position encoding embedding, with shape (PE, D).
+            For training mode, PE = 2 * U - 1.
 
         Returns:
           A tuple containing 2 tensors:
@@ -336,7 +410,13 @@ class EmformerAttention(nn.Module):
             _,
             _,
         ) = self._forward_impl(
-            utterance, lengths, right_context, summary, memory, attention_mask
+            utterance,
+            lengths,
+            right_context,
+            summary,
+            memory,
+            attention_mask,
+            pos_emb,
         )
         return output_right_context_utterance, output_memory[:-1]
 
@@ -350,6 +430,7 @@ class EmformerAttention(nn.Module):
         memory: torch.Tensor,
         left_context_key: torch.Tensor,
         left_context_val: torch.Tensor,
+        pos_emb: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """Forward pass for inference.
 
@@ -379,6 +460,9 @@ class EmformerAttention(nn.Module):
           left_context_val (torch.Tensor):
             Cached attention value of left context from preceding computation,
             with shape (L, B, D).
+          pos_emb (torch.Tensor):
+            Position encoding embedding, with shape (PE, D).
+            For infer mode, PE = L + 2 * U - 1.
 
         Returns:
           A tuple containing 4 tensors:
@@ -394,9 +478,9 @@ class EmformerAttention(nn.Module):
         # key, value: [memory, right context, left context, utterance]
         KV = (
             memory.size(0)
-            + right_context.size(0)
-            + left_context_key.size(0)
-            + utterance.size(0)
+            + right_context.size(0)  # noqa
+            + left_context_key.size(0)  # noqa
+            + utterance.size(0)  # noqa
         )
         attention_mask = torch.zeros(Q, KV).to(
             dtype=torch.bool, device=utterance.device
@@ -415,6 +499,7 @@ class EmformerAttention(nn.Module):
             summary,
             memory,
             attention_mask,
+            pos_emb,
             left_context_key=left_context_key,
             left_context_val=left_context_val,
         )
@@ -643,6 +728,7 @@ class EmformerLayer(nn.Module):
         right_context_end_idx: int,
         lengths: torch.Tensor,
         memory: torch.Tensor,
+        pos_emb: torch.Tensor,
         attention_mask: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Apply attention module in non-infer mode."""
@@ -671,6 +757,7 @@ class EmformerLayer(nn.Module):
             summary=summary,
             memory=memory,
             attention_mask=attention_mask,
+            pos_emb=pos_emb,
         )
         right_context_utterance = residual + self.dropout(
             output_right_context_utterance
@@ -684,6 +771,7 @@ class EmformerLayer(nn.Module):
         right_context_end_idx: int,
         lengths: torch.Tensor,
         memory: torch.Tensor,
+        pos_emb: torch.Tensor,
         state: Optional[List[torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
         """Apply attention in infer mode.
@@ -717,6 +805,14 @@ class EmformerLayer(nn.Module):
             summary = torch.empty(0).to(
                 dtype=utterance.dtype, device=utterance.device
             )
+        # pos_emb is of shape [PE, D], PE = L + 2 * U - 1,
+        # the relative distance j - i of key(j) and query(i) is in range of [-(L + U - 1), (U - 1)]  # noqa
+        L = left_context_key.size(0)  # L <= left_context_length
+        U = utterance.size(0)
+        PE = L + 2 * U - 1
+        tot_PE = self.left_context_length + 2 * U - 1
+        assert pos_emb.size(0) == tot_PE
+        pos_emb = pos_emb[tot_PE - PE :]
         (
             output_right_context_utterance,
             output_memory,
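A worked example of the slicing above, not part of the diff, with toy numbers (assumptions, already in the subsampled frame domain):

# Suppose left_context_length = 32 subsampled frames are allowed, but only
# L = 20 are cached so far, and the current chunk has U = 8 utterance frames.
left_context_length, L, U = 32, 20, 8

tot_PE = left_context_length + 2 * U - 1   # 47: rows in the precomputed pos_emb
PE = L + 2 * U - 1                         # 35: rows actually needed this chunk
start = tot_PE - PE                        # 12: entries for left-context frames not cached yet
# pos_emb[start:] keeps the relative distances in [-(L + U - 1), U - 1],
# matching the assertion PE == L + 2 * U - 1 in EmformerAttention.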
@@ -730,6 +826,7 @@ class EmformerLayer(nn.Module):
             memory=pre_memory,
             left_context_key=left_context_key,
             left_context_val=left_context_val,
+            pos_emb=pos_emb,
         )
         right_context_utterance = residual + self.dropout(
             output_right_context_utterance
@@ -746,6 +843,7 @@ class EmformerLayer(nn.Module):
         right_context: torch.Tensor,
         memory: torch.Tensor,
         attention_mask: torch.Tensor,
+        pos_emb: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         r"""Forward pass for training.
         1) Apply layer normalization on input utterance and right context
@@ -774,6 +872,9 @@ class EmformerLayer(nn.Module):
           attention_mask (torch.Tensor):
             Attention mask for underlying attention module,
             with shape (Q, KV), where Q = R + U + S, KV = M + R + U.
+          pos_emb (torch.Tensor):
+            Position encoding embedding, with shape (PE, D).
+            For training mode, PE = 2 * U - 1.
 
         Returns:
           A tuple containing 3 tensors:
@@ -797,6 +898,7 @@ class EmformerLayer(nn.Module):
             lengths,
             memory,
             attention_mask,
+            pos_emb,
         )
 
         right_context_utterance = self._apply_conv_module_forward(
@@ -820,6 +922,7 @@ class EmformerLayer(nn.Module):
         lengths: torch.Tensor,
         right_context: torch.Tensor,
         memory: torch.Tensor,
+        pos_emb: torch.Tensor,
         state: Optional[List[torch.Tensor]] = None,
         conv_cache: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]:
@@ -851,6 +954,9 @@ class EmformerLayer(nn.Module):
           state (List[torch.Tensor], optional):
             List of tensors representing layer internal state generated in
             preceding computation. (default=None)
+          pos_emb (torch.Tensor):
+            Position encoding embedding, with shape (PE, D).
+            For infer mode, PE = L + 2 * U - 1.
           conv_cache (torch.Tensor, optional):
             Cache tensor of left context for causal convolution.
 
@@ -878,6 +984,7 @@ class EmformerLayer(nn.Module):
             right_context_end_idx,
             lengths,
             memory,
+            pos_emb,
             state,
         )
 
@@ -1124,7 +1231,10 @@ class EmformerEncoder(nn.Module):
         return attention_mask
 
     def forward(
-        self, x: torch.Tensor, lengths: torch.Tensor
+        self,
+        x: torch.Tensor,
+        lengths: torch.Tensor,
+        pos_emb: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward pass for training and non-streaming inference.
 
@@ -1140,6 +1250,9 @@ class EmformerEncoder(nn.Module):
             With shape (B,) and i-th element representing number of valid
             utterance frames for i-th batch element in x, which contains the
             right_context at the end.
+          pos_emb (torch.Tensor):
+            Position encoding embedding, with shape (PE, D).
+            For training mode, PE = 2 * U - 1.
 
         Returns:
           A tuple of 2 tensors:
@@ -1161,7 +1274,12 @@ class EmformerEncoder(nn.Module):
         output = utterance
         for layer in self.emformer_layers:
             output, right_context, memory = layer(
-                output, output_lengths, right_context, memory, attention_mask
+                output,
+                output_lengths,
+                right_context,
+                memory,
+                attention_mask,
+                pos_emb,
             )
 
         return output, output_lengths
@@ -1171,6 +1289,7 @@ class EmformerEncoder(nn.Module):
         self,
         x: torch.Tensor,
         lengths: torch.Tensor,
+        pos_emb: torch.Tensor,
         states: Optional[List[List[torch.Tensor]]] = None,
         conv_caches: Optional[List[torch.Tensor]] = None,
     ) -> Tuple[
@@ -1190,6 +1309,9 @@ class EmformerEncoder(nn.Module):
             With shape (B,) and i-th element representing number of valid
             utterance frames for i-th batch element in x, which contains the
             right_context at the end.
+          pos_emb (torch.Tensor):
+            Position encoding embedding, with shape (PE, D).
+            For infer mode, PE = L + 2 * U - 1.
           states (List[List[torch.Tensor]], optional):
             Cached states from preceding chunk's computation, where each
             element (List[torch.Tensor]) corresponds to each emformer layer.
@@ -1234,6 +1356,7 @@ class EmformerEncoder(nn.Module):
                 output_lengths,
                 right_context,
                 memory,
+                pos_emb,
                 None if states is None else states[layer_idx],
                 None if conv_caches is None else conv_caches[layer_idx],
             )
@@ -1291,6 +1414,8 @@ class Emformer(EncoderInterface):
         else:
             self.encoder_embed = Conv2dSubsampling(num_features, d_model)
 
+        self.encoder_pos = RelPositionalEncoding(d_model, dropout)
+
         self.encoder = EmformerEncoder(
             chunk_length // 4,
             d_model,
@@ -1338,6 +1463,10 @@ class Emformer(EncoderInterface):
             right_context at the end.
         """
         x = self.encoder_embed(x)
+
+        # TODO: The length computation in the encoder class should be moved here.  # noqa
+        U = x.size(1) - self.right_context_length // 4
+        x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
         # Caution: We assume the subsampling factor is 4!
@@ -1346,7 +1475,7 @@ class Emformer(EncoderInterface):
         x_lens = ((x_lens - 1) // 2 - 1) // 2
         assert x.size(0) == x_lens.max().item()
 
-        output, output_lengths = self.encoder(x, x_lens)  # (T, N, C)
+        output, output_lengths = self.encoder(x, x_lens, pos_emb)  # (T, N, C)
 
         logits = self.encoder_output_layer(output)
         logits = logits.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
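A rough arithmetic sketch, not part of the diff, of the positional-embedding length this training forward produces, using toy numbers and the subsampling factor of 4 noted in the code:

# Toy numbers, assuming 240 input frames and right_context_length = 8 input frames.
num_frames, right_context_length = 240, 8

T = ((num_frames - 1) // 2 - 1) // 2        # 59 frames after Conv2dSubsampling, as in x_lens
U = T - right_context_length // 4           # 57 utterance frames, right context excluded

# encoder_pos is called with pos_len=U and neg_len=U, so pos_emb has
# 2 * U - 1 = 113 rows: one per relative distance in [-(U - 1), U - 1],
# matching the training-mode assertion PE == 2 * U - 1 in EmformerAttention.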
@@ -1392,6 +1521,12 @@ class Emformer(EncoderInterface):
           - updated convolution caches from current chunk.
         """
         x = self.encoder_embed(x)
+
+        # TODO: The length computation in the encoder class should be moved here.  # noqa
+        pos_len = self.chunk_length // 4 + self.left_context_length // 4
+        neg_len = self.chunk_length // 4
+        x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)
+
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
         # Caution: We assume the subsampling factor is 4!
@@ -1405,7 +1540,7 @@ class Emformer(EncoderInterface):
             output_lengths,
             output_states,
             output_conv_caches,
-        ) = self.encoder.infer(x, x_lens, states, conv_caches)
+        ) = self.encoder.infer(x, x_lens, pos_emb, states, conv_caches)
 
         logits = self.encoder_output_layer(output)
         logits = logits.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
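A worked check, not part of the diff, that the streaming pos_len/neg_len above line up with the tot_PE expected in EmformerLayer._apply_attention_module_infer; toy values, and it assumes the encoder and layers store the already subsampled chunk and left-context lengths:

# Toy values in input frames (assumptions for illustration).
chunk_length, left_context_length = 32, 128

U = chunk_length // 4                                     # 8 subsampled chunk frames
pos_len = chunk_length // 4 + left_context_length // 4    # 8 + 32 = 40
neg_len = chunk_length // 4                               # 8

# get_pe concatenates pos_len positive rows with neg_len - 1 negative rows,
# so pos_emb has pos_len + neg_len - 1 = 47 rows, which equals
# tot_PE = (left_context_length // 4) + 2 * U - 1 = 32 + 16 - 1 = 47.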
@@ -1533,6 +1668,111 @@ class ConvolutionModule(nn.Module):
         return x.permute(2, 0, 1), new_cache
 
 
+class RelPositionalEncoding(torch.nn.Module):
+    """Relative positional encoding module.
+
+    See: Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"  # noqa
+    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py  # noqa
+
+    Args:
+      d_model: Embedding dimension.
+      dropout_rate: Dropout rate.
+      max_len: Maximum input length.
+
+    """
+
+    def __init__(
+        self, d_model: int, dropout_rate: float, max_len: int = 5000
+    ) -> None:
+        """Construct a PositionalEncoding object."""
+        super(RelPositionalEncoding, self).__init__()
+        self.d_model = d_model
+        self.xscale = math.sqrt(self.d_model)
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.pe = None
+        self.pos_len = max_len
+        self.neg_len = max_len
+        self.gen_pe()
+
+    def gen_pe(self) -> None:
+        """Generate the positional encodings."""
+        # Suppose `i` is the position of the query vector and `j` is the
+        # position of the key vector. We use positive relative positions when
+        # keys are to the left (i > j) and negative relative positions
+        # otherwise (i < j).
+        pe_positive = torch.zeros(self.pos_len, self.d_model)
+        pe_negative = torch.zeros(self.neg_len, self.d_model)
+        position_positive = torch.arange(
+            0, self.pos_len, dtype=torch.float32
+        ).unsqueeze(1)
+        position_negative = torch.arange(
+            0, self.neg_len, dtype=torch.float32
+        ).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe_positive[:, 0::2] = torch.sin(position_positive * div_term)
+        pe_positive[:, 1::2] = torch.cos(position_positive * div_term)
+        pe_negative[:, 0::2] = torch.sin(-1 * position_negative * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position_negative * div_term)
+
+        # Reverse the order of positive indices and concat both positive and
+        # negative indices. This is used to support the shifting trick
+        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"  # noqa
+        self.pe_positive = torch.flip(pe_positive, [0])
+        self.pe_negative = pe_negative
+        # self.pe = torch.cat([pe_positive, pe_negative], dim=1)
+
+    def get_pe(
+        self,
+        pos_len: int,
+        neg_len: int,
+        device: torch.device,
+        dtype: torch.dtype,
+    ) -> torch.Tensor:
+        """Get positional encoding given positive length and negative length."""
+        if self.pe_positive.dtype != dtype or str(
+            self.pe_positive.device
+        ) != str(device):
+            self.pe_positive = self.pe_positive.to(dtype=dtype, device=device)
+        if self.pe_negative.dtype != dtype or str(
+            self.pe_negative.device
+        ) != str(device):
+            self.pe_negative = self.pe_negative.to(dtype=dtype, device=device)
+        pe = torch.cat(
+            [
+                self.pe_positive[self.pos_len - pos_len :],
+                self.pe_negative[1:neg_len],
+            ],
+            dim=0,
+        )
+        return pe
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        pos_len: int,
+        neg_len: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Add positional encoding.
+
+        Args:
+          x (torch.Tensor): Input tensor (batch, time, `*`).
+
+        Returns:
+          torch.Tensor: Encoded tensor (batch, time, `*`).
+          torch.Tensor: Position embedding tensor (pos_len + neg_len - 1, `*`).
+
+        """
+        x = x * self.xscale
+        if pos_len > self.pos_len or neg_len > self.neg_len:
+            self.pos_len = pos_len
+            self.neg_len = neg_len
+            self.gen_pe()
+        pos_emb = self.get_pe(pos_len, neg_len, x.device, x.dtype)
+        return self.dropout(x), self.dropout(pos_emb)
+
+
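A small usage sketch of the RelPositionalEncoding class above, not part of the diff, for the training case pos_len = neg_len = U (toy values; assumes the class is importable from this module):

import torch

pos_enc = RelPositionalEncoding(d_model=256, dropout_rate=0.1)

U = 10                         # subsampled utterance length (toy value)
x = torch.randn(4, U, 256)     # (batch, time, d_model)
x_scaled, pos_emb = pos_enc(x, pos_len=U, neg_len=U)

print(x_scaled.shape)   # torch.Size([4, 10, 256])
print(pos_emb.shape)    # torch.Size([19, 256]) -> 2 * U - 1 rows, as EmformerAttention expects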
 class Swish(torch.nn.Module):
     """Construct a Swish object."""
 