add pos enc on memory, left context, chunk, and right context (in key)

yaozengwei 2022-05-20 11:13:29 +08:00
parent f94ad976c6
commit 19a8700301
2 changed files with 242 additions and 58 deletions


@ -450,6 +450,10 @@ class EmformerAttention(nn.Module):
Embedding dimension.
nhead (int):
Number of attention heads in each Emformer layer.
chunk_length (int):
Length of each input chunk.
right_context_length (int):
Length of right context.
dropout (float, optional):
Dropout probability. (Default: 0.0)
tanh_on_mem (bool, optional):
@ -462,6 +466,8 @@ class EmformerAttention(nn.Module):
self,
embed_dim: int,
nhead: int,
chunk_length: int,
right_context_length: int,
dropout: float = 0.0,
tanh_on_mem: bool = False,
negative_inf: float = -1e8,
@ -479,6 +485,8 @@ class EmformerAttention(nn.Module):
self.tanh_on_mem = tanh_on_mem
self.negative_inf = negative_inf
self.head_dim = embed_dim // nhead
self.chunk_length = chunk_length
self.right_context_length = right_context_length
self.dropout = dropout
self.emb_to_key_value = ScaledLinear(
@ -572,13 +580,16 @@ class EmformerAttention(nn.Module):
Args:
x: Input tensor, of shape (B, nhead, U, PE).
U is the length of query vector.
For training and validation mode, PE = 2 * U - 1;
for inference mode, PE = L + 2 * U - 1.
For training and validation mode,
PE = 2 * U + right_context_length - 1.
For inference mode,
PE = tot_left_length + 2 * U + right_context_length - 1,
where tot_left_length = M * chunk_length.
Returns:
A tensor of shape (B, nhead, U, out_len).
For non-infer mode, out_len = U;
for infer mode, out_len = L + U.
For training and validation mode, out_len = U + right_context_length.
For inference mode, out_len = tot_left_length + U + right_context_length. # noqa
"""
B, nhead, U, PE = x.size()
B_stride = x.stride(0)
@ -592,6 +603,33 @@ class EmformerAttention(nn.Module):
storage_offset=PE_stride * (U - 1),
)
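The as_strided call above is the rel-shift step: each query row re-reads its own window of relative-position scores in place, without materializing a shifted copy. Below is a rough standalone sketch of one common formulation of this trick, with toy shapes (the names and sizes are illustrative, not taken from this commit), assuming training/validation mode where PE = 2 * U + right_context_length - 1 and out_len = U + right_context_length:

import torch

B, nhead, U, right_context_length = 1, 1, 4, 2
PE = 2 * U + right_context_length - 1  # 9, training/validation mode
x = torch.arange(B * nhead * U * PE, dtype=torch.float32).view(B, nhead, U, PE)

out_len = U + right_context_length  # 6
shifted = x.as_strided(
    size=(B, nhead, U, out_len),
    stride=(x.stride(0), x.stride(1), x.stride(2) - x.stride(3), x.stride(3)),
    storage_offset=x.stride(3) * (U - 1),
)
# Query row i now starts at column (U - 1 - i) of row i in x, so each
# utterance frame reads exactly the block of relative offsets that belongs
# to it, with no data copied.
print(shifted.shape)  # torch.Size([1, 1, 4, 6])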
def _get_right_context_part(
self, matrix_bd_utterance: torch.Tensor
) -> torch.Tensor:
"""
Args:
matrix_bd_utterance:
(B * nhead, U, U + right_context_length)
Returns:
A tensor of shape (B * nhead, U, R),
where R = num_chunks * right_context_length.
"""
assert self.right_context_length > 0
U = matrix_bd_utterance.size(1)
num_chunks = math.ceil(U / self.chunk_length)
right_context_blocks = []
for i in range(num_chunks - 1):
start_idx = (i + 1) * self.chunk_length
end_idx = start_idx + self.right_context_length
right_context_blocks.append(
matrix_bd_utterance[:, :, start_idx:end_idx]
)
right_context_blocks.append(
matrix_bd_utterance[:, :, -self.right_context_length :]
)
return torch.cat(right_context_blocks, dim=2)
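The hard-copied right context of chunk i is simply the first right_context_length frames of chunk i + 1, while the last chunk uses the frames appended after the utterance, so the helper above just gathers those columns. A toy illustration with made-up sizes (standalone, not the diff's tensors):

import math
import torch

chunk_length, right_context_length = 2, 1
U = 6  # 3 chunks of length 2
num_chunks = math.ceil(U / chunk_length)

# Columns 0..U-1 are utterance frames; the last right_context_length
# columns are the appended right-context frames.
x = torch.arange(U + right_context_length, dtype=torch.float32).repeat(1, U, 1)
blocks = []
for i in range(num_chunks - 1):
    start = (i + 1) * chunk_length  # right context of chunk i
    blocks.append(x[:, :, start : start + right_context_length])
blocks.append(x[:, :, -right_context_length:])  # last chunk's right context
right_part = torch.cat(blocks, dim=2)
print(right_part[0, 0])  # tensor([2., 4., 6.])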
def _forward_impl(
self,
utterance: torch.Tensor,
@ -603,7 +641,15 @@ class EmformerAttention(nn.Module):
pos_emb: torch.Tensor,
left_context_key: Optional[torch.Tensor] = None,
left_context_val: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
need_weights: bool = False,
) -> Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
"""Underlying chunk-wise attention implementation."""
U, B, _ = utterance.size()
R = right_context.size(0)
@ -664,10 +710,12 @@ class EmformerAttention(nn.Module):
if left_context_key is not None and left_context_val is not None:
# inference mode
L = left_context_key.size(0)
assert PE == L + 2 * U - 1
tot_left_length = M * self.chunk_length if M > 0 else L
assert tot_left_length >= L
assert PE == tot_left_length + 2 * U + self.right_context_length - 1
else:
# training and validation mode
assert PE == 2 * U - 1
assert PE == 2 * U + self.right_context_length - 1
pos_emb = (
self.linear_pos(pos_emb)
.view(PE, self.nhead, self.head_dim)
@ -679,13 +727,49 @@ class EmformerAttention(nn.Module):
) # (B, nhead, U, PE)
# rel-shift operation
matrix_bd_utterance = self._rel_shift(matrix_bd_utterance)
# (B, nhead, U, U) for training and validation mode;
# (B, nhead, U, L + U) for inference mode.
# (B, nhead, U, U + right_context_length) for training and validation mode; # noqa
# (B, nhead, U, tot_left_length + U + right_context_length) for inference mode. # noqa
matrix_bd_utterance = matrix_bd_utterance.contiguous().view(
B * self.nhead, U, -1
)
matrix_bd = torch.zeros_like(matrix_ac)
matrix_bd[:, R : R + U, M + R :] = matrix_bd_utterance
if left_context_key is not None and left_context_val is not None:
# inference mode
# key: [memory, right context, left context, utterance]
# for memory
if M > 0:
matrix_bd[:, R : R + U, :M] = torch.nn.functional.avg_pool2d(
matrix_bd_utterance[:, :, :tot_left_length].unsqueeze(1),
kernel_size=(1, self.chunk_length),
stride=(1, self.chunk_length),
).squeeze(1)
# for right_context
if R > 0:
matrix_bd[:, R : R + U, M : M + R] = matrix_bd_utterance[
:, :, tot_left_length + U :
]
# for left_context and utterance
matrix_bd[:, R : R + U, M + R :] = matrix_bd_utterance[
:, :, tot_left_length - L : tot_left_length + U
]
else:
# training and validation mode
# key: [memory, right context, utterance]
# for memory
if M > 0:
matrix_bd[:, R : R + U, :M] = torch.nn.functional.avg_pool2d(
matrix_bd_utterance[:, :, :U].unsqueeze(1),
kernel_size=(1, self.chunk_length),
stride=(1, self.chunk_length),
ceil_mode=True,
).squeeze(1)[:, :, :-1]
# for right_context
if R > 0:
matrix_bd[
:, R : R + U, M : M + R
] = self._get_right_context_part(matrix_bd_utterance)
# for utterance
matrix_bd[:, R : R + U, M + R :] = matrix_bd_utterance[:, :, :U]
attention_weights = matrix_ac + matrix_bd
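The avg_pool2d calls above turn per-frame positional scores into per-memory-vector scores: each memory vector summarizes one chunk, so its score is the mean of the scores over that chunk's chunk_length columns. A standalone toy sketch with made-up shapes (not the diff's tensors):

import torch
import torch.nn.functional as F

chunk_length, M = 2, 3
tot_left_length = M * chunk_length  # 6 frame columns covered by memory
B_nhead, U = 1, 2
scores = torch.arange(
    B_nhead * U * tot_left_length, dtype=torch.float32
).view(B_nhead, U, tot_left_length)

memory_scores = F.avg_pool2d(
    scores.unsqueeze(1),  # (B * nhead, 1, U, tot_left_length)
    kernel_size=(1, chunk_length),
    stride=(1, chunk_length),
).squeeze(1)  # (B * nhead, U, M)
print(memory_scores)
# tensor([[[ 0.5000,  2.5000,  4.5000],
#          [ 6.5000,  8.5000, 10.5000]]])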
@ -717,7 +801,29 @@ class EmformerAttention(nn.Module):
else:
output_memory = torch.clamp(output_memory, min=-10, max=10)
return output_right_context_utterance, output_memory, key, value
if need_weights:
# average over attention heads
attention_probs = attention_probs.reshape(B, self.nhead, Q, KV)
attention_probs = attention_probs.sum(dim=1) / self.nhead
probs_memory = attention_probs[:, R : R + U, :M].sum(dim=2)
probs_frames = attention_probs[:, R : R + U, M:].sum(dim=2)
return (
output_right_context_utterance,
output_memory,
key,
value,
probs_memory,
probs_frames,
)
return (
output_right_context_utterance,
output_memory,
key,
value,
None,
None,
)
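The need_weights branch above reports, for every utterance frame, how much attention mass falls on memory vectors versus on ordinary frames (left context, utterance, right context). A rough self-contained sketch with made-up shapes (the sizes below are illustrative only):

import torch

B, nhead = 1, 2
M, R, U, S = 2, 1, 4, 1
Q, KV = R + U + S, M + R + U
probs = torch.softmax(torch.randn(B * nhead, Q, KV), dim=-1)

# Average over heads, then split the mass of each utterance query frame
# between memory keys (first M columns) and frame keys (the rest).
probs = probs.reshape(B, nhead, Q, KV).mean(dim=1)
probs_memory = probs[:, R : R + U, :M].sum(dim=2)  # (B, U)
probs_frames = probs[:, R : R + U, M:].sum(dim=2)  # (B, U)
assert torch.allclose(probs_memory + probs_frames, torch.ones(B, U))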
def forward(
self,
@ -766,10 +872,11 @@ class EmformerAttention(nn.Module):
s_i (in query) -> l_i, c_i, r_i (in key).
Relative positional encoding is applied on the part of attention between
utterance (in query) and utterance (in key). Actually, it is applied on
the part of attention between each chunk (in query) and itself as well
as its left context (in key), after applying the mask:
c_i -> l_i, c_i.
[utterance] (in query) and [memory, right_context, utterance] (in key).
More precisely, it is applied on the part of attention between each chunk
(in query) and itself, its memory vectors, left context, and right
context (in key), after applying the mask:
c_i (in query) -> l_i, c_i, r_i, m_i (in key).
Args:
utterance (torch.Tensor):
@ -791,18 +898,23 @@ class EmformerAttention(nn.Module):
attention, with shape (Q, KV).
pos_emb (torch.Tensor):
Position encoding embedding, with shape (PE, D).
where PE = 2 * U - 1.
where PE = 2 * U + right_context_length - 1.
Returns:
A tuple containing 4 tensors:
- output of right context and utterance, with shape (R + U, B, D).
- memory output, with shape (M, B, D), where M = S - 1 or M = 0.
- summary of attention weights on memory, with shape (B, U).
- summary of attention weights on left context, utterance, and
right context, with shape (B, U).
"""
(
output_right_context_utterance,
output_memory,
_,
_,
probs_memory,
probs_frames,
) = self._forward_impl(
utterance,
lengths,
@ -811,8 +923,14 @@ class EmformerAttention(nn.Module):
memory,
attention_mask,
pos_emb,
need_weights=True,
)
return (
output_right_context_utterance,
output_memory[:-1],
probs_memory,
probs_frames,
)
return output_right_context_utterance, output_memory[:-1]
def infer(
self,
@ -849,9 +967,9 @@ class EmformerAttention(nn.Module):
left context, chunk, right context, memory vectors (in key);
summary (in query) -> left context, chunk, right context (in key).
Relative positional encoding is applied on the part of attention between
chunk (in query) and chunk itself as well as its left context (in key):
chunk (in query) -> left context, chunk (in key).
Relative positional encoding is applied on the part of attention:
chunk (in query) ->
left context, chunk, right context, memory vectors (in key).
Args:
utterance (torch.Tensor):
@ -874,7 +992,7 @@ class EmformerAttention(nn.Module):
with shape (L, B, D), where L <= left_context_length.
pos_emb (torch.Tensor):
Position encoding embedding, with shape (PE, D),
where PE = L + 2 * U - 1.
where PE = M * chunk_length + 2 * U + right_context_length - 1 if M > 0,
else PE = L + 2 * U + right_context_length - 1.
Returns:
A tuple containing 4 tensors:
@ -905,6 +1023,8 @@ class EmformerAttention(nn.Module):
output_memory,
key,
value,
_,
_,
) = self._forward_impl(
utterance,
lengths,
@ -974,6 +1094,8 @@ class EmformerEncoderLayer(nn.Module):
self.attention = EmformerAttention(
embed_dim=d_model,
nhead=nhead,
chunk_length=chunk_length,
right_context_length=right_context_length,
dropout=dropout,
tanh_on_mem=tanh_on_mem,
negative_inf=negative_inf,
@ -1018,6 +1140,7 @@ class EmformerEncoderLayer(nn.Module):
self.layer_dropout = layer_dropout
self.left_context_length = left_context_length
self.chunk_length = chunk_length
self.right_context_length = right_context_length
self.max_memory_size = max_memory_size
self.d_model = d_model
self.use_memory = max_memory_size > 0
@ -1140,7 +1263,7 @@ class EmformerEncoderLayer(nn.Module):
summary = torch.empty(0).to(
dtype=utterance.dtype, device=utterance.device
)
output_right_context_utterance, output_memory = self.attention(
output_right_context_utterance, output_memory, _, _ = self.attention(
utterance=utterance,
lengths=lengths,
right_context=right_context,
@ -1190,14 +1313,26 @@ class EmformerEncoderLayer(nn.Module):
summary = torch.empty(0).to(
dtype=utterance.dtype, device=utterance.device
)
# pos_emb is of shape [PE, D], where PE = L + 2 * U - 1,
# for query of [utterance] (i), key-value [left_context, utterance] (j),
# the max relative distance i - j is L + U - 1
# the min relative distance i - j is -(U - 1)
L = left_context_key.size(0) # L <= left_context_length
U = utterance.size(0)
PE = L + 2 * U - 1
tot_PE = self.left_context_length + 2 * U - 1
# pos_emb is of shape [PE, D], where
# PE = M * chunk_length + 2 * U + right_context_length - 1,
# for query of [utterance] (i), key-value [memory vectors, left context, utterance, right context] (j) # noqa
# the max relative distance i - j is M * chunk_length + U - 1
# the min relative distance i - j is -(U + right_context_length - 1)
M = pre_memory.size(0) # M <= max_memory_size
if self.max_memory_size > 0:
PE = M * self.chunk_length + 2 * U + self.right_context_length - 1
tot_PE = (
self.max_memory_size * self.chunk_length
+ 2 * U
+ self.right_context_length
- 1
)
else:
L = left_context_key.size(0)
PE = L + 2 * U + self.right_context_length - 1
tot_PE = (
self.left_context_length + 2 * U + self.right_context_length - 1
)
assert pos_emb.size(0) == tot_PE
pos_emb = pos_emb[tot_PE - PE :]
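With memory enabled, the cached pos_emb covers the worst case of max_memory_size full memory banks, and each streaming step keeps only the trailing PE rows that match the memory actually present. Worked arithmetic with made-up numbers (illustrative only):

chunk_length, right_context_length, max_memory_size = 4, 2, 3
U = chunk_length  # one chunk per streaming step
M = 2             # memory banks accumulated so far, M <= max_memory_size

PE = M * chunk_length + 2 * U + right_context_length - 1
tot_PE = max_memory_size * chunk_length + 2 * U + right_context_length - 1
print(PE, tot_PE, tot_PE - PE)  # 17 21 4 -> drop the first 4 rows of pos_emb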
(
@ -1661,7 +1796,9 @@ class EmformerEncoder(nn.Module):
right_context at the end.
"""
U = x.size(0) - self.right_context_length
x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
x, pos_emb = self.encoder_pos(
x, pos_len=U, neg_len=U + self.right_context_length
)
right_context = self._gen_right_context(x)
utterance = x[:U]
@ -1734,8 +1871,12 @@ class EmformerEncoder(nn.Module):
f"for dimension 1 of x, but got {x.size(1)}."
)
pos_len = self.chunk_length + self.left_context_length
neg_len = self.chunk_length
pos_len = (
self.max_memory_size * self.chunk_length + self.chunk_length
if self.max_memory_size > 0
else self.left_context_length + self.chunk_length
)
neg_len = self.chunk_length + self.right_context_length
x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)
right_context_start_idx = x.size(0) - self.right_context_length
@ -1807,6 +1948,13 @@ class Emformer(EncoderInterface):
raise NotImplementedError(
"right_context_length must be 0 or a mutiple of 4."
)
if (
max_memory_size > 0
and max_memory_size * chunk_length < left_context_length
):
raise NotImplementedError(
"max_memory_size * chunk_length can not be less than left_context_length" # noqa
)
# self.encoder_embed converts the input of shape (N, T, num_features)
# to the shape (N, T//subsampling_factor, d_model).


@ -22,7 +22,12 @@ def test_emformer_attention_forward():
num_chunks = 3
U = num_chunks * chunk_length
R = num_chunks * right_context_length
attention = EmformerAttention(embed_dim=D, nhead=8)
attention = EmformerAttention(
embed_dim=D,
nhead=8,
chunk_length=chunk_length,
right_context_length=right_context_length,
)
for use_memory in [True, False]:
if use_memory:
@ -39,10 +44,15 @@ def test_emformer_attention_forward():
summary = torch.randn(S, B, D)
memory = torch.randn(M, B, D)
attention_mask = torch.rand(Q, KV) >= 0.5
PE = 2 * U - 1
PE = 2 * U + right_context_length - 1
pos_emb = torch.randn(PE, D)
output_right_context_utterance, output_memory = attention(
(
output_right_context_utterance,
output_memory,
probs_memory,
probs_frames,
) = attention(
utterance,
lengths,
right_context,
@ -53,16 +63,26 @@ def test_emformer_attention_forward():
)
assert output_right_context_utterance.shape == (R + U, B, D)
assert output_memory.shape == (M, B, D)
assert probs_memory.shape == (B, U)
assert probs_frames.shape == (B, U)
def test_emformer_attention_infer():
from emformer import EmformerAttention
B, D = 2, 256
U = 4
R = 2
chunk_length = 4
right_context_length = 2
num_chunks = 1
U = chunk_length * num_chunks
R = right_context_length * num_chunks
L = 3
attention = EmformerAttention(embed_dim=D, nhead=8)
attention = EmformerAttention(
embed_dim=D,
nhead=8,
chunk_length=chunk_length,
right_context_length=right_context_length,
)
for use_memory in [True, False]:
if use_memory:
@ -78,7 +98,12 @@ def test_emformer_attention_infer():
memory = torch.randn(M, B, D)
left_context_key = torch.randn(L, B, D)
left_context_val = torch.randn(L, B, D)
PE = L + 2 * U - 1
PE = (
2 * U
+ right_context_length
- 1
+ (M * chunk_length if M > 0 else L)
)
pos_emb = torch.randn(PE, D)
(
@ -197,7 +222,7 @@ def test_emformer_encoder_layer_forward():
right_context = torch.randn(R, B, D)
memory = torch.randn(M, B, D)
attention_mask = torch.rand(Q, KV) >= 0.5
PE = 2 * U - 1
PE = 2 * U + right_context_length - 1
pos_emb = torch.randn(PE, D)
output_utterance, output_right_context, output_memory = layer(
@ -227,8 +252,10 @@ def test_emformer_encoder_layer_infer():
for use_memory in [True, False]:
if use_memory:
M = 3
max_memory_size = 3
M = 1
else:
max_memory_size = 0
M = 0
layer = EmformerEncoderLayer(
@ -239,7 +266,7 @@ def test_emformer_encoder_layer_infer():
cnn_module_kernel=kernel_size,
left_context_length=left_context_length,
right_context_length=right_context_length,
max_memory_size=M,
max_memory_size=max_memory_size,
)
utterance = torch.randn(U, B, D)
@ -248,7 +275,16 @@ def test_emformer_encoder_layer_infer():
right_context = torch.randn(R, B, D)
memory = torch.randn(M, B, D)
state = None
PE = left_context_length + 2 * U - 1
PE = (
2 * U
+ right_context_length
- 1
+ (
max_memory_size * chunk_length
if max_memory_size > 0
else left_context_length
)
)
pos_emb = torch.randn(PE, D)
conv_cache = None
(
@ -273,7 +309,7 @@ def test_emformer_encoder_layer_infer():
else:
assert output_memory.shape == (0, B, D)
assert len(output_state) == 4
assert output_state[0].shape == (M, B, D)
assert output_state[0].shape == (max_memory_size, B, D)
assert output_state[1].shape == (left_context_length, B, D)
assert output_state[2].shape == (left_context_length, B, D)
assert output_state[3].shape == (1, B)
@ -334,9 +370,9 @@ def test_emformer_encoder_infer():
for use_memory in [True, False]:
if use_memory:
M = 3
max_memory_size = 3
else:
M = 0
max_memory_size = 0
encoder = EmformerEncoder(
chunk_length=chunk_length,
@ -346,7 +382,7 @@ def test_emformer_encoder_infer():
cnn_module_kernel=kernel_size,
left_context_length=left_context_length,
right_context_length=right_context_length,
max_memory_size=M,
max_memory_size=max_memory_size,
)
states = None
@ -368,7 +404,7 @@ def test_emformer_encoder_infer():
assert len(states) == num_encoder_layers
for state in states:
assert len(state) == 4
assert state[0].shape == (M, B, D)
assert state[0].shape == (max_memory_size, B, D)
assert state[1].shape == (left_context_length, B, D)
assert state[2].shape == (left_context_length, B, D)
assert torch.equal(
@ -391,7 +427,7 @@ def test_emformer_encoder_forward_infer_consistency():
kernel_size = 31
memory_sizes = [0, 3]
for M in memory_sizes:
for max_memory_size in memory_sizes:
encoder = EmformerEncoder(
chunk_length=chunk_length,
d_model=D,
@ -400,7 +436,7 @@ def test_emformer_encoder_forward_infer_consistency():
cnn_module_kernel=kernel_size,
left_context_length=left_context_length,
right_context_length=right_context_length,
max_memory_size=M,
max_memory_size=max_memory_size,
)
encoder.eval()
@ -449,9 +485,9 @@ def test_emformer_forward():
for use_memory in [True, False]:
if use_memory:
M = 3
max_memory_size = 3
else:
M = 0
max_memory_size = 0
model = Emformer(
num_features=num_features,
chunk_length=chunk_length,
@ -460,7 +496,7 @@ def test_emformer_forward():
cnn_module_kernel=kernel_size,
left_context_length=left_context_length,
right_context_length=right_context_length,
max_memory_size=M,
max_memory_size=max_memory_size,
)
x = torch.randn(B, U + right_context_length + 3, num_features)
x_lens = torch.randint(1, U + right_context_length + 3 + 1, (B,))
@ -481,7 +517,7 @@ def test_emformer_infer():
num_features = 80
chunk_length = 8
U = chunk_length
left_context_length, right_context_length = 128, 4
left_context_length, right_context_length = 32, 4
B, D = 2, 256
num_chunks = 3
num_encoder_layers = 2
@ -489,9 +525,9 @@ def test_emformer_infer():
for use_memory in [True, False]:
if use_memory:
M = 3
max_memory_size = 32
else:
M = 0
max_memory_size = 0
model = Emformer(
num_features=num_features,
chunk_length=chunk_length,
@ -501,7 +537,7 @@ def test_emformer_infer():
cnn_module_kernel=kernel_size,
left_context_length=left_context_length,
right_context_length=right_context_length,
max_memory_size=M,
max_memory_size=max_memory_size,
)
states = None
conv_caches = None
@ -523,7 +559,7 @@ def test_emformer_infer():
assert len(states) == num_encoder_layers
for state in states:
assert len(state) == 4
assert state[0].shape == (M, B, D)
assert state[0].shape == (max_memory_size, B, D)
assert state[1].shape == (left_context_length // 4, B, D)
assert state[2].shape == (left_context_length // 4, B, D)
assert torch.equal(