add emformer attention module

yaozengwei 2022-05-13 17:07:40 +08:00
parent d74e2e8e07
commit 2cfb2f58f0
2 changed files with 573 additions and 0 deletions


@@ -151,3 +151,485 @@ class RelPositionalEncoding(torch.nn.Module):
self.gen_pe_negative()
pos_emb = self.get_pe(pos_len, neg_len, x.device, x.dtype)
return self.dropout(x), self.dropout(pos_emb)
class EmformerAttention(nn.Module):
r"""Emformer layer attention module.
Args:
embed_dim (int):
Embedding dimension.
nhead (int):
Number of attention heads in each Emformer layer.
dropout (float, optional):
Dropout probability. (Default: 0.0)
tanh_on_mem (bool, optional):
If ``True``, applies tanh to memory elements. (Default: ``False``)
negative_inf (float, optional):
Value to use for negative infinity in attention weights. (Default: -1e8)
"""
def __init__(
self,
embed_dim: int,
nhead: int,
dropout: float = 0.0,
tanh_on_mem: bool = False,
negative_inf: float = -1e8,
):
super().__init__()
if embed_dim % nhead != 0:
raise ValueError(
f"embed_dim ({embed_dim}) is not a multiple of"
f"nhead ({nhead})."
)
self.embed_dim = embed_dim
self.nhead = nhead
self.tanh_on_mem = tanh_on_mem
self.negative_inf = negative_inf
self.head_dim = embed_dim // nhead
self.dropout = dropout
self.emb_to_key_value = ScaledLinear(
embed_dim, 2 * embed_dim, bias=True
)
self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True)
self.out_proj = ScaledLinear(
embed_dim, embed_dim, bias=True, initial_scale=0.25
)
# linear transformation for positional encoding.
self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
# these two learnable biases are used in matrix c and matrix d
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 # noqa
self.pos_bias_u = nn.Parameter(torch.Tensor(nhead, self.head_dim))
self.pos_bias_v = nn.Parameter(torch.Tensor(nhead, self.head_dim))
self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
self._reset_parameters()
def _pos_bias_u(self):
return self.pos_bias_u * self.pos_bias_u_scale.exp()
def _pos_bias_v(self):
return self.pos_bias_v * self.pos_bias_v_scale.exp()
def _reset_parameters(self) -> None:
nn.init.normal_(self.pos_bias_u, std=0.01)
nn.init.normal_(self.pos_bias_v, std=0.01)
def _gen_attention_probs(
self,
attention_weights: torch.Tensor,
attention_mask: torch.Tensor,
padding_mask: Optional[torch.Tensor],
) -> torch.Tensor:
"""Given the entire attention weights, mask out unecessary connections
and optionally with padding positions, to obtain underlying chunk-wise
attention probabilities.
B: batch size;
Q: length of query;
KV: length of key and value.
Args:
attention_weights (torch.Tensor):
Attention weights computed on the entire concatenated tensor
with shape (B * nhead, Q, KV).
attention_mask (torch.Tensor):
Mask tensor where chunk-wise connections are filled with `False`,
and other unnecessary connections are filled with `True`,
with shape (Q, KV).
padding_mask (torch.Tensor, optional):
Mask tensor where the padding positions are filled with `True`,
and other positions are filled with `False`, with shape `(B, KV)`.
Returns:
A tensor of shape (B * nhead, Q, KV).
"""
attention_weights_float = attention_weights.float()
attention_weights_float = attention_weights_float.masked_fill(
attention_mask.unsqueeze(0), self.negative_inf
)
if padding_mask is not None:
Q = attention_weights.size(1)
B = attention_weights.size(0) // self.nhead
attention_weights_float = attention_weights_float.view(
B, self.nhead, Q, -1
)
attention_weights_float = attention_weights_float.masked_fill(
padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
self.negative_inf,
)
attention_weights_float = attention_weights_float.view(
B * self.nhead, Q, -1
)
attention_probs = nn.functional.softmax(
attention_weights_float, dim=-1
).type_as(attention_weights)
attention_probs = nn.functional.dropout(
attention_probs, p=self.dropout, training=self.training
)
return attention_probs
def _rel_shift(self, x: torch.Tensor) -> torch.Tensor:
"""Compute relative positional encoding.
Args:
x: Input tensor, of shape (B, nhead, U, PE).
U is the length of query vector.
For training mode, PE = 2 * U - 1;
for inference mode, PE = L + 2 * U - 1.
Returns:
A tensor of shape (B, nhead, U, out_len).
For training mode, out_len = U;
for inference mode, out_len = L + U.
"""
B, nhead, U, PE = x.size()
B_stride = x.stride(0)
nhead_stride = x.stride(1)
U_stride = x.stride(2)
PE_stride = x.stride(3)
out_len = PE - (U - 1)
return x.as_strided(
size=(B, nhead, U, out_len),
stride=(B_stride, nhead_stride, U_stride - PE_stride, PE_stride),
storage_offset=PE_stride * (U - 1),
)
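# --- Illustration only, not part of this commit ---------------------------
# A minimal, self-contained sketch of the strided view used by `_rel_shift`
# above, in training mode (PE = 2 * U - 1, out_len = U). All names below are
# illustrative.
def rel_shift_demo():
    import torch

    B, nhead, U = 1, 1, 3
    PE = 2 * U - 1
    # x[b, h, i, k] = k: every row enumerates the PE relative positions.
    x = torch.arange(PE).view(1, 1, 1, PE).expand(B, nhead, U, PE).contiguous()
    out_len = PE - (U - 1)
    shifted = x.as_strided(
        size=(B, nhead, U, out_len),
        stride=(x.stride(0), x.stride(1), x.stride(2) - x.stride(3), x.stride(3)),
        storage_offset=x.stride(3) * (U - 1),
    )
    # Row i of the result starts at column U - 1 - i of x, i.e.
    # shifted[b, h, i, j] == x[b, h, i, U - 1 - i + j]:
    #   [[2, 3, 4],
    #    [1, 2, 3],
    #    [0, 1, 2]]
    assert shifted[0, 0].tolist() == [[2, 3, 4], [1, 2, 3], [0, 1, 2]]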
def _forward_impl(
self,
utterance: torch.Tensor,
lengths: torch.Tensor,
right_context: torch.Tensor,
summary: torch.Tensor,
memory: torch.Tensor,
attention_mask: torch.Tensor,
pos_emb: torch.Tensor,
left_context_key: Optional[torch.Tensor] = None,
left_context_val: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Underlying chunk-wise attention implementation."""
U, B, _ = utterance.size()
R = right_context.size(0)
M = memory.size(0)
scaling = float(self.head_dim) ** -0.5
# compute query with [right_context, utterance, summary].
query = self.emb_to_query(
torch.cat([right_context, utterance, summary])
)
# compute key and value with [memory, right_context, utterance].
key, value = self.emb_to_key_value(
torch.cat([memory, right_context, utterance])
).chunk(chunks=2, dim=2)
if left_context_key is not None and left_context_val is not None:
# now compute key and value with
# [memory, right context, left context, utterance]
# this is used in inference mode
key = torch.cat([key[: M + R], left_context_key, key[M + R :]])
value = torch.cat(
[value[: M + R], left_context_val, value[M + R :]]
)
Q = query.size(0)
KV = key.size(0)
reshaped_key, reshaped_value = [
tensor.contiguous()
.view(KV, B * self.nhead, self.head_dim)
.transpose(0, 1)
for tensor in [key, value]
] # both of shape (B * nhead, KV, head_dim)
reshaped_query = (
query.contiguous().view(Q, B, self.nhead, self.head_dim) * scaling
)
# compute attention score
# first, compute attention matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 # noqa
query_with_bias_u = (
(reshaped_query + self._pos_bias_u())
.view(Q, B * self.nhead, self.head_dim)
.transpose(0, 1)
) # (B * nhead, Q, head_dim)
matrix_ac = torch.bmm(
query_with_bias_u, reshaped_key.transpose(1, 2)
) # (B * nhead, Q, KV)
# second, compute attention matrix b and matrix d
# relative positional encoding is applied on the part of attention
# between chunk (in query) and itself as well as its left context
# (in key)
utterance_with_bias_v = (
reshaped_query[R : R + U] + self._pos_bias_v()
).permute(1, 2, 0, 3)
# (B, nhead, U, head_dim)
PE = pos_emb.size(0)
if left_context_key is not None and left_context_val is not None:
# inference mode
L = left_context_key.size(0)
assert PE == L + 2 * U - 1
else:
# training mode
assert PE == 2 * U - 1
pos_emb = (
self.linear_pos(pos_emb)
.view(PE, self.nhead, self.head_dim)
.transpose(0, 1)
.unsqueeze(0)
) # (1, nhead, PE, head_dim)
matrix_bd_utterance = torch.matmul(
utterance_with_bias_v, pos_emb.transpose(-2, -1)
) # (B, nhead, U, PE)
# rel-shift operation
matrix_bd_utterance = self._rel_shift(matrix_bd_utterance)
# (B, nhead, U, U) for training mode;
# (B, nhead, U, L + U) for inference mode.
matrix_bd_utterance = matrix_bd_utterance.contiguous().view(
B * self.nhead, U, -1
)
matrix_bd = torch.zeros_like(matrix_ac)
matrix_bd[:, R : R + U, M + R :] = matrix_bd_utterance
attention_weights = matrix_ac + matrix_bd
# compute padding mask
if B == 1:
padding_mask = None
else:
padding_mask = make_pad_mask(KV - U + lengths)
# compute attention probabilities
attention_probs = self._gen_attention_probs(
attention_weights, attention_mask, padding_mask
)
# compute attention outputs
attention = torch.bmm(attention_probs, reshaped_value)
assert attention.shape == (B * self.nhead, Q, self.head_dim)
attention = (
attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim)
)
# apply output projection
outputs = self.out_proj(attention)
output_right_context_utterance = outputs[: R + U]
output_memory = outputs[R + U :]
if self.tanh_on_mem:
output_memory = torch.tanh(output_memory)
else:
output_memory = torch.clamp(output_memory, min=-10, max=10)
return output_right_context_utterance, output_memory, key, value
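# --- Illustration only, not part of this commit ---------------------------
# Why the padding mask above is built from ``KV - U + lengths``: the key is
# laid out as [memory, right_context, (left_context,) utterance], and only the
# trailing utterance block can contain padded frames, so batch element b has
# KV - U + lengths[b] valid key positions. ``make_pad_mask_sketch`` is a
# minimal stand-in for the helper used in the module.
def padding_mask_demo():
    import torch

    def make_pad_mask_sketch(lengths, max_len):
        # True at padded positions, False at valid ones.
        return torch.arange(max_len).unsqueeze(0) >= lengths.unsqueeze(1)

    M, R, U = 2, 2, 4
    KV = M + R + U
    lengths = torch.tensor([4, 2])  # valid utterance frames per batch element
    padding_mask = make_pad_mask_sketch(KV - U + lengths, KV)
    # Only the last two key positions of batch element 1 (its padded
    # utterance frames) are masked:
    #   [[False, False, False, False, False, False, False, False],
    #    [False, False, False, False, False, False,  True,  True]]
    assert padding_mask[1].tolist() == [False] * 6 + [True] * 2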
def forward(
self,
utterance: torch.Tensor,
lengths: torch.Tensor,
right_context: torch.Tensor,
summary: torch.Tensor,
memory: torch.Tensor,
attention_mask: torch.Tensor,
pos_emb: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
# TODO: Modify docs.
"""Forward pass for training mode.
B: batch size;
D: embedding dimension;
R: length of the hard-copied right contexts;
U: length of full utterance;
S: length of summary vectors;
M: length of memory vectors.
It computes a `big` attention matrix on full utterance and
then utilizes a pre-computed mask to simulate chunk-wise attention.
It concatenates three blocks: hard-copied right contexts,
full utterance, and summary vectors, as a `big` block,
to compute the query tensor:
query = [right_context, utterance, summary],
with length Q = R + U + S.
It concatenates the three blocks: memory vectors,
hard-copied right contexts, and full utterance as another `big` block,
to compute the key and value tensors:
key & value = [memory, right_context, utterance],
with length KV = M + R + U.
Attention scores are computed with the above `big` query and key.
Then the underlying chunk-wise attention is obtained by applying
the attention mask. Suppose
c_i: chunk at index i;
r_i: right context that c_i can use;
l_i: left context that c_i can use;
m_i: past memory vectors from previous layer that c_i can use;
s_i: summary vector of c_i;
The target chunk-wise attention is:
c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key);
s_i (in query) -> l_i, c_i, r_i (in key).
Relative positional encoding is applied on the part of attention between
utterance (in query) and utterance (in key). Actually, it is applied on
the part of attention between each chunk (in query) and itself as well
as its left context (in key), after applying the mask:
c_i -> l_i, c_i.
Args:
utterance (torch.Tensor):
Full utterance frames, with shape (U, B, D).
lengths (torch.Tensor):
With shape (B,) and i-th element representing
number of valid frames for i-th batch element in utterance.
right_context (torch.Tensor):
Hard-copied right context frames, with shape (R, B, D),
where R = num_chunks * right_context_length.
summary (torch.Tensor):
Summary elements with shape (S, B, D), where S = num_chunks.
It is an empty tensor if memory is not used.
memory (torch.Tensor):
Memory elements, with shape (M, B, D), where M = num_chunks - 1.
It is an empty tensor if memory is not used.
attention_mask (torch.Tensor):
Pre-computed attention mask to simulate underlying chunk-wise
attention, with shape (Q, KV).
pos_emb (torch.Tensor):
Position encoding embedding, with shape (PE, D),
where PE = 2 * U - 1.
Returns:
A tuple containing 2 tensors:
- output of right context and utterance, with shape (R + U, B, D).
- memory output, with shape (M, B, D), where M = S - 1 or M = 0.
"""
(
output_right_context_utterance,
output_memory,
_,
_,
) = self._forward_impl(
utterance,
lengths,
right_context,
summary,
memory,
attention_mask,
pos_emb,
)
return output_right_context_utterance, output_memory[:-1]
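# --- Illustration only, not part of this commit ---------------------------
# A rough sketch of a training-mode attention mask with the layout described
# in the docstring above (query = [right_context, utterance, summary],
# key = [memory, right_context, utterance]). It assumes chunk i may attend to
# at most `memory_size` of its preceding memory vectors and to at most
# `left_context_length` frames of left context; the mask actually used by the
# model is precomputed elsewhere in the encoder and may differ in details.
def gen_attention_mask_sketch(
    num_chunks,
    chunk_length,
    right_context_length,
    left_context_length,
    memory_size,
):
    import torch

    U = num_chunks * chunk_length
    R = num_chunks * right_context_length
    use_memory = memory_size > 0
    S = num_chunks if use_memory else 0
    M = num_chunks - 1 if use_memory else 0
    Q, KV = R + U + S, M + R + U
    # True marks disallowed connections, False the chunk-wise ones.
    mask = torch.ones(Q, KV, dtype=torch.bool)
    for i in range(num_chunks):
        allowed = torch.zeros(KV, dtype=torch.bool)
        if use_memory:
            allowed[max(0, i - memory_size) : i] = True  # m_i
        rc_start = M + i * right_context_length
        allowed[rc_start : rc_start + right_context_length] = True  # r_i
        c_start = M + R + i * chunk_length
        left_start = max(M + R, c_start - left_context_length)
        allowed[left_start : c_start + chunk_length] = True  # l_i and c_i
        # rows of r_i and c_i attend to [l_i, c_i, r_i, m_i]
        mask[i * right_context_length : (i + 1) * right_context_length] = ~allowed
        mask[R + i * chunk_length : R + (i + 1) * chunk_length] = ~allowed
        if use_memory:
            # the summary vector s_i attends to [l_i, c_i, r_i] only
            summary_allowed = allowed.clone()
            summary_allowed[:M] = False
            mask[R + U + i] = ~summary_allowed
    return mask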
def infer(
self,
utterance: torch.Tensor,
lengths: torch.Tensor,
right_context: torch.Tensor,
summary: torch.Tensor,
memory: torch.Tensor,
left_context_key: torch.Tensor,
left_context_val: torch.Tensor,
pos_emb: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Forward pass for inference.
B: batch size;
D: embedding dimension;
R: length of right context;
U: length of utterance, i.e., current chunk;
L: length of cached left context;
S: length of summary vectors, S = 1;
M: length of cached memory vectors.
It concatenates the right context, the utterance (i.e., current chunk),
and the summary vector of the current chunk, to compute the query tensor:
query = [right_context, utterance, summary],
with length Q = R + U + S.
It concatenates the memory vectors, right context, left context, and
current chunk, to compute the key and value tensors:
key & value = [memory, right_context, left_context, utterance],
with length KV = M + R + L + U.
The chunk-wise attention is:
chunk, right context (in query) ->
left context, chunk, right context, memory vectors (in key);
summary (in query) -> left context, chunk, right context (in key).
Relative positional encoding is applied on the part of attention between
chunk (in query) and chunk itself as well as its left context (in key):
chunk (in query) -> left context, chunk (in key).
Args:
utterance (torch.Tensor):
Current chunk frames, with shape (U, B, D), where U = chunk_length.
lengths (torch.Tensor):
With shape (B,) and i-th element representing
number of valid frames for i-th batch element in utterance.
right_context (torch.Tensor):
Right context frames, with shape (R, B, D),
where R = right_context_length.
summary (torch.Tensor):
Summary vector with shape (1, B, D), or empty tensor.
memory (torch.Tensor):
Memory vectors, with shape (M, B, D), or empty tensor.
left_context_key (torch.Tensor):
Cached attention key of left context from preceding computation,
with shape (L, B, D), where L <= left_context_length.
left_context_val (torch.Tensor):
Cached attention value of left context from preceding computation,
with shape (L, B, D), where L <= left_context_length.
pos_emb (torch.Tensor):
Position encoding embedding, with shape (PE, D),
where PE = L + 2 * U - 1.
Returns:
A tuple containing 4 tensors:
- output of right context and utterance, with shape (R + U, B, D).
- memory output, with shape (1, B, D) or (0, B, D).
- attention key of left context and utterance, which would be cached
for next computation, with shape (L + U, B, D).
- attention value of left context and utterance, which would be
cached for next computation, with shape (L + U, B, D).
"""
U = utterance.size(0)
R = right_context.size(0)
L = left_context_key.size(0)
S = summary.size(0)
M = memory.size(0)
# query = [right context, utterance, summary]
Q = R + U + S
# key, value = [memory, right context, left context, utterance]
KV = M + R + L + U
attention_mask = torch.zeros(Q, KV).to(
dtype=torch.bool, device=utterance.device
)
# disallow attention between the summary vector and the memory bank
attention_mask[-1, :M] = True
(
output_right_context_utterance,
output_memory,
key,
value,
) = self._forward_impl(
utterance,
lengths,
right_context,
summary,
memory,
attention_mask,
pos_emb,
left_context_key=left_context_key,
left_context_val=left_context_val,
)
return (
output_right_context_utterance,
output_memory,
key[M + R :],
value[M + R :],
)
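# --- Illustration only, not part of this commit ---------------------------
# A hedged sketch of how `infer` could be driven chunk by chunk, feeding the
# returned key/value back in as the cached left context of the next chunk.
# The cache-trimming policy (`left_context_length`) is an assumption, and the
# positional embedding is random here, as in the unit tests.
def streaming_infer_demo():
    import torch
    from emformer import EmformerAttention

    B, D = 2, 256
    U, R = 4, 2                  # chunk length and right-context length
    left_context_length = 8      # assumed maximum size of the cache
    attention = EmformerAttention(embed_dim=D, nhead=8)
    cached_key = torch.zeros(0, B, D)
    cached_val = torch.zeros(0, B, D)
    for _ in range(3):           # three consecutive chunks
        L = cached_key.size(0)
        utterance = torch.randn(U, B, D)
        lengths = torch.full((B,), U, dtype=torch.int64)
        right_context = torch.randn(R, B, D)
        summary = torch.randn(0, B, D)   # no memory bank in this sketch
        memory = torch.randn(0, B, D)
        pos_emb = torch.randn(L + 2 * U - 1, D)
        output, output_memory, next_key, next_val = attention.infer(
            utterance,
            lengths,
            right_context,
            summary,
            memory,
            cached_key,
            cached_val,
            pos_emb,
        )
        assert output.shape == (R + U, B, D)
        # keep at most `left_context_length` frames as the next left context
        cached_key = next_key[-left_context_length:]
        cached_val = next_val[-left_context_length:]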


@@ -13,5 +13,96 @@ def test_rel_positional_encoding():
assert pos_emb.shape == (pos_len + neg_len - 1, D)
def test_emformer_attention_forward():
from emformer import EmformerAttention
B, D = 2, 256
chunk_length = 4
right_context_length = 2
num_chunks = 3
U = num_chunks * chunk_length
R = num_chunks * right_context_length
attention = EmformerAttention(embed_dim=D, nhead=8)
for use_memory in [True, False]:
if use_memory:
S = num_chunks
M = S - 1
else:
S, M = 0, 0
Q, KV = R + U + S, M + R + U
utterance = torch.randn(U, B, D)
lengths = torch.randint(1, U + 1, (B,))
lengths[0] = U
right_context = torch.randn(R, B, D)
summary = torch.randn(S, B, D)
memory = torch.randn(M, B, D)
attention_mask = torch.rand(Q, KV) >= 0.5
PE = 2 * U - 1
pos_emb = torch.randn(PE, D)
output_right_context_utterance, output_memory = attention(
utterance,
lengths,
right_context,
summary,
memory,
attention_mask,
pos_emb,
)
assert output_right_context_utterance.shape == (R + U, B, D)
assert output_memory.shape == (M, B, D)
def test_emformer_attention_infer():
from emformer import EmformerAttention
B, D = 2, 256
U = 4
R = 2
L = 3
attention = EmformerAttention(embed_dim=D, nhead=8)
for use_memory in [True, False]:
if use_memory:
S, M = 1, 3
else:
S, M = 0, 0
utterance = torch.randn(U, B, D)
lengths = torch.randint(1, U + 1, (B,))
lengths[0] = U
right_context = torch.randn(R, B, D)
summary = torch.randn(S, B, D)
memory = torch.randn(M, B, D)
left_context_key = torch.randn(L, B, D)
left_context_val = torch.randn(L, B, D)
PE = L + 2 * U - 1
pos_emb = torch.randn(PE, D)
(
output_right_context_utterance,
output_memory,
next_key,
next_val,
) = attention.infer(
utterance,
lengths,
right_context,
summary,
memory,
left_context_key,
left_context_val,
pos_emb,
)
assert output_right_context_utterance.shape == (R + U, B, D)
assert output_memory.shape == (S, B, D)
assert next_key.shape == (L + U, B, D)
assert next_val.shape == (L + U, B, D)
if __name__ == "__main__":
test_rel_positional_encoding()
test_emformer_attention_forward()
test_emformer_attention_infer()