Repository: https://github.com/k2-fsa/icefall.git
add pos enc on memory, left context, chunk, and right context (in key)
Commit 19a8700301 (parent f94ad976c6)
@@ -450,6 +450,10 @@ class EmformerAttention(nn.Module):
         Embedding dimension.
       nhead (int):
         Number of attention heads in each Emformer layer.
+      chunk_length (int):
+        Length of each input chunk.
+      right_context_length (int):
+        Length of right context.
       dropout (float, optional):
         Dropout probability. (Default: 0.0)
       tanh_on_mem (bool, optional):
@@ -462,6 +466,8 @@ class EmformerAttention(nn.Module):
         self,
         embed_dim: int,
         nhead: int,
+        chunk_length: int,
+        right_context_length: int,
         dropout: float = 0.0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
@@ -479,6 +485,8 @@ class EmformerAttention(nn.Module):
         self.tanh_on_mem = tanh_on_mem
         self.negative_inf = negative_inf
         self.head_dim = embed_dim // nhead
+        self.chunk_length = chunk_length
+        self.right_context_length = right_context_length
         self.dropout = dropout

         self.emb_to_key_value = ScaledLinear(
@@ -572,13 +580,16 @@ class EmformerAttention(nn.Module):
         Args:
           x: Input tensor, of shape (B, nhead, U, PE).
             U is the length of query vector.
-            For training and validation mode, PE = 2 * U - 1;
-            for inference mode, PE = L + 2 * U - 1.
+            For training and validation mode,
+            PE = 2 * U + right_context_length - 1.
+            For inference mode,
+            PE = tot_left_length + 2 * U + right_context_length - 1,
+            where tot_left_length = M * chunk_length.

         Returns:
           A tensor of shape (B, nhead, U, out_len).
-          For non-infer mode, out_len = U;
-          for infer mode, out_len = L + U.
+          For training and validation mode, out_len = U + right_context_length.
+          For inference mode, out_len = tot_left_length + U + right_context_length.  # noqa
         """
         B, nhead, U, PE = x.size()
         B_stride = x.stride(0)
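Note: a minimal standalone sketch of the strided relative-shift that _rel_shift performs, in the simplest setting (training mode with right_context_length = 0). The toy sizes and stride arithmetic below are illustrative and reconstructed around the storage_offset shown above; they are not copied verbatim from the file.

    import torch

    # Row i of the output picks columns (U - 1 - i) .. (U - 1 - i + out_len - 1)
    # of row i of the input, i.e. each successive query row is shifted one
    # column to the left, turning "score per relative distance" into
    # "score per key position".
    B, nhead, U = 1, 1, 3
    PE = 2 * U - 1          # training mode, right_context_length = 0
    x = torch.arange(B * nhead * U * PE, dtype=torch.float32).reshape(B, nhead, U, PE)
    out_len = PE - (U - 1)  # = U here
    shifted = x.as_strided(
        size=(B, nhead, U, out_len),
        stride=(x.stride(0), x.stride(1), x.stride(2) - x.stride(3), x.stride(3)),
        storage_offset=x.stride(3) * (U - 1),
    )
    # row 0 -> columns 2..4, row 1 -> columns 1..3, row 2 -> columns 0..2
    print(shifted)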
@@ -592,6 +603,33 @@ class EmformerAttention(nn.Module):
             storage_offset=PE_stride * (U - 1),
         )

+    def _get_right_context_part(
+        self, matrix_bd_utterance: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Args:
+          matrix_bd_utterance:
+            (B * nhead, U, U + right_context_length)
+
+        Returns:
+          A tensor of shape (B * nhead, U, R),
+          where R = num_chunks * right_context_length.
+        """
+        assert self.right_context_length > 0
+        U = matrix_bd_utterance.size(1)
+        num_chunks = math.ceil(U / self.chunk_length)
+        right_context_blocks = []
+        for i in range(num_chunks - 1):
+            start_idx = (i + 1) * self.chunk_length
+            end_idx = start_idx + self.right_context_length
+            right_context_blocks.append(
+                matrix_bd_utterance[:, :, start_idx:end_idx]
+            )
+        right_context_blocks.append(
+            matrix_bd_utterance[:, :, -self.right_context_length :]
+        )
+        return torch.cat(right_context_blocks, dim=2)
+
     def _forward_impl(
         self,
         utterance: torch.Tensor,
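Note: a toy illustration of the column gathering done by the new helper. For each chunk it keeps the columns of the shifted score matrix that correspond to that chunk's hard-copied right-context frames (the trailing columns for the last chunk). The sizes below are made up for illustration.

    import math
    import torch

    chunk_length, right_context_length = 2, 1
    U = 6  # three chunks of length 2
    num_chunks = math.ceil(U / chunk_length)
    # stand-in for matrix_bd_utterance: (B * nhead, U, U + right_context_length)
    matrix_bd_utterance = torch.arange(U + right_context_length).float().expand(1, U, -1)
    blocks = []
    for i in range(num_chunks - 1):
        start = (i + 1) * chunk_length
        blocks.append(matrix_bd_utterance[:, :, start : start + right_context_length])
    blocks.append(matrix_bd_utterance[:, :, -right_context_length:])
    out = torch.cat(blocks, dim=2)
    # columns picked: 2, 4 and 6 -- one right-context column per chunk
    assert out.shape == (1, U, num_chunks * right_context_length)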
@@ -603,7 +641,15 @@ class EmformerAttention(nn.Module):
         pos_emb: torch.Tensor,
         left_context_key: Optional[torch.Tensor] = None,
         left_context_val: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        need_weights=False,
+    ) -> Tuple[
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+    ]:
         """Underlying chunk-wise attention implementation."""
         U, B, _ = utterance.size()
         R = right_context.size(0)
@@ -664,10 +710,12 @@ class EmformerAttention(nn.Module):
         if left_context_key is not None and left_context_val is not None:
             # inference mode
             L = left_context_key.size(0)
-            assert PE == L + 2 * U - 1
+            tot_left_length = M * self.chunk_length if M > 0 else L
+            assert tot_left_length >= L
+            assert PE == tot_left_length + 2 * U + self.right_context_length - 1
         else:
             # training and validation mode
-            assert PE == 2 * U - 1
+            assert PE == 2 * U + self.right_context_length - 1
         pos_emb = (
             self.linear_pos(pos_emb)
             .view(PE, self.nhead, self.head_dim)
@@ -679,13 +727,49 @@ class EmformerAttention(nn.Module):
         )  # (B, nhead, U, PE)
         # rel-shift operation
         matrix_bd_utterance = self._rel_shift(matrix_bd_utterance)
-        # (B, nhead, U, U) for training and validation mode;
-        # (B, nhead, U, L + U) for inference mode.
+        # (B, nhead, U, U + right_context_length) for training and validation mode;  # noqa
+        # (B, nhead, U, tot_left_length + U + right_context_length) for inference mode.  # noqa
         matrix_bd_utterance = matrix_bd_utterance.contiguous().view(
             B * self.nhead, U, -1
         )
         matrix_bd = torch.zeros_like(matrix_ac)
-        matrix_bd[:, R : R + U, M + R :] = matrix_bd_utterance
+        if left_context_key is not None and left_context_val is not None:
+            # inference mode
+            # key: [memory, right context, left context, utterance]
+            # for memory
+            if M > 0:
+                matrix_bd[:, R : R + U, :M] = torch.nn.functional.avg_pool2d(
+                    matrix_bd_utterance[:, :, :tot_left_length].unsqueeze(1),
+                    kernel_size=(1, self.chunk_length),
+                    stride=(1, self.chunk_length),
+                ).squeeze(1)
+            # for right_context
+            if R > 0:
+                matrix_bd[:, R : R + U, M : M + R] = matrix_bd_utterance[
+                    :, :, tot_left_length + U :
+                ]
+            # for left_context and utterance
+            matrix_bd[:, R : R + U, M + R :] = matrix_bd_utterance[
+                :, :, tot_left_length - L : tot_left_length + U
+            ]
+        else:
+            # training and validation mode
+            # key: [memory, right context, utterance]
+            # for memory
+            if M > 0:
+                matrix_bd[:, R : R + U, :M] = torch.nn.functional.avg_pool2d(
+                    matrix_bd_utterance[:, :, :U].unsqueeze(1),
+                    kernel_size=(1, self.chunk_length),
+                    stride=(1, self.chunk_length),
+                    ceil_mode=True,
+                ).squeeze(1)[:, :, :-1]
+            # for right_context
+            if R > 0:
+                matrix_bd[
+                    :, R : R + U, M : M + R
+                ] = self._get_right_context_part(matrix_bd_utterance)
+            # for utterance
+            matrix_bd[:, R : R + U, M + R :] = matrix_bd_utterance[:, :, :U]

         attention_weights = matrix_ac + matrix_bd

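Note: a hedged toy sketch of the pooling used above for the memory columns. Each memory vector summarizes one chunk, so its relative-position score is approximated by averaging the per-frame scores of the frames in that chunk; the sizes below are invented for illustration.

    import torch
    import torch.nn.functional as F

    B_nhead, U, chunk_length = 2, 6, 2
    scores = torch.randn(B_nhead, U, U)       # per-frame scores (training mode)
    pooled = F.avg_pool2d(
        scores.unsqueeze(1),                  # (B*nhead, 1, U, U)
        kernel_size=(1, chunk_length),
        stride=(1, chunk_length),
        ceil_mode=True,
    ).squeeze(1)                              # (B*nhead, U, num_chunks)
    # In training mode the last chunk has no memory vector yet, hence the
    # trailing [:, :, :-1] in the diff.
    memory_scores = pooled[:, :, :-1]
    assert memory_scores.shape == (B_nhead, U, U // chunk_length - 1)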
@@ -717,7 +801,29 @@ class EmformerAttention(nn.Module):
         else:
             output_memory = torch.clamp(output_memory, min=-10, max=10)

-        return output_right_context_utterance, output_memory, key, value
+        if need_weights:
+            # average over attention heads
+            attention_probs = attention_probs.reshape(B, self.nhead, Q, KV)
+            attention_probs = attention_probs.sum(dim=1) / self.nhead
+            probs_memory = attention_probs[:, R : R + U, :M].sum(dim=2)
+            probs_frames = attention_probs[:, R : R + U, M:].sum(dim=2)
+            return (
+                output_right_context_utterance,
+                output_memory,
+                key,
+                value,
+                probs_memory,
+                probs_frames,
+            )
+
+        return (
+            output_right_context_utterance,
+            output_memory,
+            key,
+            value,
+            None,
+            None,
+        )

     def forward(
         self,
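Note: a small self-contained sketch of the weight summary computed when need_weights is true. Head-averaged attention mass over the utterance queries is split into the part on memory and the part on everything else, so the two summaries sum to one per frame. All sizes below are invented; KV mirrors the training-mode key layout [memory, right context, utterance].

    import torch

    B, nhead, R, U, M = 2, 4, 2, 6, 3
    Q = R + U + 1            # right context + utterance + one summary query
    KV = M + R + U           # memory + right context + utterance keys
    attention_probs = torch.softmax(torch.randn(B * nhead, Q, KV), dim=-1)
    probs = attention_probs.reshape(B, nhead, Q, KV).sum(dim=1) / nhead
    probs_memory = probs[:, R : R + U, :M].sum(dim=2)  # (B, U), mass on memory
    probs_frames = probs[:, R : R + U, M:].sum(dim=2)  # (B, U), mass on frames
    assert torch.allclose(probs_memory + probs_frames, torch.ones(B, U))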
@@ -766,10 +872,11 @@ class EmformerAttention(nn.Module):
           s_i (in query) -> l_i, c_i, r_i (in key).

         Relative positional encoding is applied on the part of attention between
-        utterance (in query) and utterance (in key). Actually, it is applied on
-        the part of attention between each chunk (in query) and itself as well
-        as its left context (in key), after applying the mask:
-          c_i -> l_i, c_i.
+        [utterance] (in query) and [memory, right_context, utterance] (in key).
+        Actually, it is applied on the part of attention between each chunk
+        (in query) and itself, its memory vectors, left context, and right
+        context (in key), after applying the mask:
+          c_i (in query) -> l_i, c_i, r_i, m_i (in key).

         Args:
           utterance (torch.Tensor):
@@ -791,18 +898,23 @@ class EmformerAttention(nn.Module):
             attention, with shape (Q, KV).
           pos_emb (torch.Tensor):
             Position encoding embedding, with shape (PE, D).
-            where PE = 2 * U - 1.
+            where PE = 2 * U + right_context_length - 1.

         Returns:
           A tuple containing 2 tensors:
             - output of right context and utterance, with shape (R + U, B, D).
             - memory output, with shape (M, B, D), where M = S - 1 or M = 0.
+            - summary of attention weights on memory, with shape (B, U).
+            - summary of attention weights on left context, utterance, and
+              right context, with shape (B, U).
         """
         (
             output_right_context_utterance,
             output_memory,
             _,
             _,
+            probs_memory,
+            probs_frames,
         ) = self._forward_impl(
             utterance,
             lengths,
@@ -811,8 +923,14 @@ class EmformerAttention(nn.Module):
             memory,
             attention_mask,
             pos_emb,
+            need_weights=True,
+        )
+        return (
+            output_right_context_utterance,
+            output_memory[:-1],
+            probs_memory,
+            probs_frames,
         )
-        return output_right_context_utterance, output_memory[:-1]

     def infer(
         self,
@@ -849,9 +967,9 @@ class EmformerAttention(nn.Module):
             left context, chunk, right context, memory vectors (in key);
           summary (in query) -> left context, chunk, right context (in key).

-        Relative positional encoding is applied on the part of attention between
-        chunk (in query) and chunk itself as well as its left context (in key):
-          chunk (in query) -> left context, chunk (in key).
+        Relative positional encoding is applied on the part of attention:
+          chunk (in query) ->
+            left context, chunk, right context, memory vectors (in key);

         Args:
           utterance (torch.Tensor):
@@ -874,7 +992,7 @@ class EmformerAttention(nn.Module):
             with shape (L, B, D), where L <= left_context_length.
           pos_emb (torch.Tensor):
             Position encoding embedding, with shape (PE, D),
-            where PE = L + 2 * U - 1.
+            where PE = M * chunk_length + 2 * U - 1 if M > 0 else L + 2 * U - 1.

         Returns:
           A tuple containing 4 tensors:
@@ -905,6 +1023,8 @@ class EmformerAttention(nn.Module):
             output_memory,
             key,
             value,
+            _,
+            _,
         ) = self._forward_impl(
             utterance,
             lengths,
@@ -974,6 +1094,8 @@ class EmformerEncoderLayer(nn.Module):
         self.attention = EmformerAttention(
             embed_dim=d_model,
             nhead=nhead,
+            chunk_length=chunk_length,
+            right_context_length=right_context_length,
             dropout=dropout,
             tanh_on_mem=tanh_on_mem,
             negative_inf=negative_inf,
@@ -1018,6 +1140,7 @@ class EmformerEncoderLayer(nn.Module):
         self.layer_dropout = layer_dropout
         self.left_context_length = left_context_length
         self.chunk_length = chunk_length
+        self.right_context_length = right_context_length
         self.max_memory_size = max_memory_size
         self.d_model = d_model
         self.use_memory = max_memory_size > 0
@@ -1140,7 +1263,7 @@ class EmformerEncoderLayer(nn.Module):
             summary = torch.empty(0).to(
                 dtype=utterance.dtype, device=utterance.device
             )
-        output_right_context_utterance, output_memory = self.attention(
+        output_right_context_utterance, output_memory, _, _ = self.attention(
             utterance=utterance,
             lengths=lengths,
             right_context=right_context,
@@ -1190,14 +1313,26 @@ class EmformerEncoderLayer(nn.Module):
             summary = torch.empty(0).to(
                 dtype=utterance.dtype, device=utterance.device
             )
-        # pos_emb is of shape [PE, D], where PE = L + 2 * U - 1,
-        # for query of [utterance] (i), key-value [left_context, utterance] (j),
-        # the max relative distance i - j is L + U - 1
-        # the min relative distance i - j is -(U - 1)
-        L = left_context_key.size(0)  # L <= left_context_length
         U = utterance.size(0)
-        PE = L + 2 * U - 1
-        tot_PE = self.left_context_length + 2 * U - 1
+        # pos_emb is of shape [PE, D], where PE = M * chunk_length + 2 * U - 1,
+        # for query of [utterance] (i), key-value [memory vectors, left context, utterance, right context] (j)  # noqa
+        # the max relative distance i - j is M * chunk_length + U - 1
+        # the min relative distance i - j is -(U + right_context_length - 1)
+        M = pre_memory.size(0)  # M <= max_memory_size
+        if self.max_memory_size > 0:
+            PE = M * self.chunk_length + 2 * U + self.right_context_length - 1
+            tot_PE = (
+                self.max_memory_size * self.chunk_length
+                + 2 * U
+                + self.right_context_length
+                - 1
+            )
+        else:
+            L = left_context_key.size(0)
+            PE = L + 2 * U + self.right_context_length - 1
+            tot_PE = (
+                self.left_context_length + 2 * U + self.right_context_length - 1
+            )
         assert pos_emb.size(0) == tot_PE
         pos_emb = pos_emb[tot_PE - PE :]
         (
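Note: a worked example of the two lengths above; the numbers are illustrative, not taken from the diff.

    chunk_length, right_context_length, max_memory_size = 8, 2, 4
    U, M = 8, 3  # current chunk length and number of cached memory vectors
    PE = M * chunk_length + 2 * U + right_context_length - 1                    # 41
    tot_PE = max_memory_size * chunk_length + 2 * U + right_context_length - 1  # 49
    # pos_emb is generated once for the largest possible span (tot_PE rows);
    # pos_emb[tot_PE - PE:] drops the most-distant positive positions that are
    # unused while the memory bank is not yet full.
    assert tot_PE - PE == (max_memory_size - M) * chunk_length  # 8 rows dropped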
@@ -1661,7 +1796,9 @@ class EmformerEncoder(nn.Module):
             right_context at the end.
         """
         U = x.size(0) - self.right_context_length
-        x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
+        x, pos_emb = self.encoder_pos(
+            x, pos_len=U, neg_len=U + self.right_context_length
+        )

         right_context = self._gen_right_context(x)
         utterance = x[:U]
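Note: a quick sanity check of the training-mode length (illustrative numbers), under the assumption that encoder_pos returns pos_len + neg_len - 1 embeddings covering relative distances pos_len - 1 down to -(neg_len - 1).

    U, right_context_length = 8, 2
    pos_len, neg_len = U, U + right_context_length
    # 2 * U + right_context_length - 1 = 17, the training/validation PE
    # expected by the assert in _forward_impl above.
    assert pos_len + neg_len - 1 == 2 * U + right_context_length - 1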
@@ -1734,8 +1871,12 @@ class EmformerEncoder(nn.Module):
                 f"for dimension 1 of x, but got {x.size(1)}."
             )

-        pos_len = self.chunk_length + self.left_context_length
-        neg_len = self.chunk_length
+        pos_len = (
+            self.max_memory_size * self.chunk_length + self.chunk_length
+            if self.max_memory_size > 0
+            else self.left_context_length + self.chunk_length
+        )
+        neg_len = self.chunk_length + self.right_context_length
         x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)

         right_context_start_idx = x.size(0) - self.right_context_length
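Note: the same convention worked out for the streaming case (illustrative numbers).

    chunk_length, left_context_length, right_context_length = 8, 32, 2
    max_memory_size = 4
    pos_len = (
        max_memory_size * chunk_length + chunk_length   # memory span + current chunk
        if max_memory_size > 0
        else left_context_length + chunk_length
    )
    neg_len = chunk_length + right_context_length       # current chunk + right context
    # PE = pos_len + neg_len - 1 = 49, matching tot_PE in EmformerEncoderLayer above.
    assert pos_len + neg_len - 1 == 49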
@@ -1807,6 +1948,13 @@ class Emformer(EncoderInterface):
             raise NotImplementedError(
                 "right_context_length must be 0 or a mutiple of 4."
             )
+        if (
+            max_memory_size > 0
+            and max_memory_size * chunk_length < left_context_length
+        ):
+            raise NotImplementedError(
+                "max_memory_size * chunk_length can not be less than left_context_length"  # noqa
+            )

         # self.encoder_embed converts the input of shape (N, T, num_features)
         # to the shape (N, T//subsampling_factor, d_model).
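Note: a hedged reading of why the new check is needed, based on the asserts added in _forward_impl: with memory enabled, the positional span reserved on the left is M * chunk_length, and the cached left context (up to left_context_length frames) must fit inside that span, so in the worst case max_memory_size * chunk_length >= left_context_length must hold.

    chunk_length, left_context_length, max_memory_size = 8, 32, 4
    # mirrors "assert tot_left_length >= L" with a full memory bank and a full
    # left-context cache (illustrative numbers)
    assert max_memory_size * chunk_length >= left_context_length

The remaining hunks update the unit tests to the new constructor arguments, PE lengths, and return values.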
@@ -22,7 +22,12 @@ def test_emformer_attention_forward():
     num_chunks = 3
     U = num_chunks * chunk_length
     R = num_chunks * right_context_length
-    attention = EmformerAttention(embed_dim=D, nhead=8)
+    attention = EmformerAttention(
+        embed_dim=D,
+        nhead=8,
+        chunk_length=chunk_length,
+        right_context_length=right_context_length,
+    )

     for use_memory in [True, False]:
         if use_memory:
@@ -39,10 +44,15 @@ def test_emformer_attention_forward():
         summary = torch.randn(S, B, D)
         memory = torch.randn(M, B, D)
         attention_mask = torch.rand(Q, KV) >= 0.5
-        PE = 2 * U - 1
+        PE = 2 * U + right_context_length - 1
         pos_emb = torch.randn(PE, D)

-        output_right_context_utterance, output_memory = attention(
+        (
+            output_right_context_utterance,
+            output_memory,
+            probs_memory,
+            probs_frames,
+        ) = attention(
             utterance,
             lengths,
             right_context,
@@ -53,16 +63,26 @@ def test_emformer_attention_forward():
         )
         assert output_right_context_utterance.shape == (R + U, B, D)
         assert output_memory.shape == (M, B, D)
+        assert probs_memory.shape == (B, U)
+        assert probs_frames.shape == (B, U)


 def test_emformer_attention_infer():
     from emformer import EmformerAttention

     B, D = 2, 256
-    U = 4
-    R = 2
+    chunk_length = 4
+    right_context_length = 2
+    num_chunks = 1
+    U = chunk_length * num_chunks
+    R = right_context_length * num_chunks
     L = 3
-    attention = EmformerAttention(embed_dim=D, nhead=8)
+    attention = EmformerAttention(
+        embed_dim=D,
+        nhead=8,
+        chunk_length=chunk_length,
+        right_context_length=right_context_length,
+    )

     for use_memory in [True, False]:
         if use_memory:
@@ -78,7 +98,12 @@ def test_emformer_attention_infer():
         memory = torch.randn(M, B, D)
         left_context_key = torch.randn(L, B, D)
         left_context_val = torch.randn(L, B, D)
-        PE = L + 2 * U - 1
+        PE = (
+            2 * U
+            + right_context_length
+            - 1
+            + (M * chunk_length if M > 0 else L)
+        )
         pos_emb = torch.randn(PE, D)

         (
@@ -197,7 +222,7 @@ def test_emformer_encoder_layer_forward():
         right_context = torch.randn(R, B, D)
         memory = torch.randn(M, B, D)
         attention_mask = torch.rand(Q, KV) >= 0.5
-        PE = 2 * U - 1
+        PE = 2 * U + right_context_length - 1
         pos_emb = torch.randn(PE, D)

         output_utterance, output_right_context, output_memory = layer(
@@ -227,8 +252,10 @@ def test_emformer_encoder_layer_infer():

     for use_memory in [True, False]:
         if use_memory:
-            M = 3
+            max_memory_size = 3
+            M = 1
         else:
+            max_memory_size = 0
             M = 0

         layer = EmformerEncoderLayer(
@@ -239,7 +266,7 @@ def test_emformer_encoder_layer_infer():
             cnn_module_kernel=kernel_size,
             left_context_length=left_context_length,
             right_context_length=right_context_length,
-            max_memory_size=M,
+            max_memory_size=max_memory_size,
         )

         utterance = torch.randn(U, B, D)
@@ -248,7 +275,16 @@ def test_emformer_encoder_layer_infer():
         right_context = torch.randn(R, B, D)
         memory = torch.randn(M, B, D)
         state = None
-        PE = left_context_length + 2 * U - 1
+        PE = (
+            2 * U
+            + right_context_length
+            - 1
+            + (
+                max_memory_size * chunk_length
+                if max_memory_size > 0
+                else left_context_length
+            )
+        )
         pos_emb = torch.randn(PE, D)
         conv_cache = None
         (
@@ -273,7 +309,7 @@ def test_emformer_encoder_layer_infer():
         else:
             assert output_memory.shape == (0, B, D)
         assert len(output_state) == 4
-        assert output_state[0].shape == (M, B, D)
+        assert output_state[0].shape == (max_memory_size, B, D)
         assert output_state[1].shape == (left_context_length, B, D)
         assert output_state[2].shape == (left_context_length, B, D)
         assert output_state[3].shape == (1, B)
@@ -334,9 +370,9 @@ def test_emformer_encoder_infer():

     for use_memory in [True, False]:
         if use_memory:
-            M = 3
+            max_memory_size = 3
         else:
-            M = 0
+            max_memory_size = 0

         encoder = EmformerEncoder(
             chunk_length=chunk_length,
@@ -346,7 +382,7 @@ def test_emformer_encoder_infer():
             cnn_module_kernel=kernel_size,
             left_context_length=left_context_length,
             right_context_length=right_context_length,
-            max_memory_size=M,
+            max_memory_size=max_memory_size,
         )

         states = None
@@ -368,7 +404,7 @@ def test_emformer_encoder_infer():
         assert len(states) == num_encoder_layers
         for state in states:
             assert len(state) == 4
-            assert state[0].shape == (M, B, D)
+            assert state[0].shape == (max_memory_size, B, D)
             assert state[1].shape == (left_context_length, B, D)
             assert state[2].shape == (left_context_length, B, D)
             assert torch.equal(
@@ -391,7 +427,7 @@ def test_emformer_encoder_forward_infer_consistency():
     kernel_size = 31
     memory_sizes = [0, 3]

-    for M in memory_sizes:
+    for max_memory_size in memory_sizes:
         encoder = EmformerEncoder(
             chunk_length=chunk_length,
             d_model=D,
@@ -400,7 +436,7 @@ def test_emformer_encoder_forward_infer_consistency():
             cnn_module_kernel=kernel_size,
             left_context_length=left_context_length,
             right_context_length=right_context_length,
-            max_memory_size=M,
+            max_memory_size=max_memory_size,
         )
         encoder.eval()

@@ -449,9 +485,9 @@ def test_emformer_forward():

     for use_memory in [True, False]:
         if use_memory:
-            M = 3
+            max_memory_size = 3
         else:
-            M = 0
+            max_memory_size = 0
         model = Emformer(
             num_features=num_features,
             chunk_length=chunk_length,
@@ -460,7 +496,7 @@ def test_emformer_forward():
             cnn_module_kernel=kernel_size,
             left_context_length=left_context_length,
             right_context_length=right_context_length,
-            max_memory_size=M,
+            max_memory_size=max_memory_size,
         )
         x = torch.randn(B, U + right_context_length + 3, num_features)
         x_lens = torch.randint(1, U + right_context_length + 3 + 1, (B,))
@@ -481,7 +517,7 @@ def test_emformer_infer():
     num_features = 80
     chunk_length = 8
     U = chunk_length
-    left_context_length, right_context_length = 128, 4
+    left_context_length, right_context_length = 32, 4
     B, D = 2, 256
     num_chunks = 3
     num_encoder_layers = 2
@@ -489,9 +525,9 @@ def test_emformer_infer():

     for use_memory in [True, False]:
         if use_memory:
-            M = 3
+            max_memory_size = 32
         else:
-            M = 0
+            max_memory_size = 0
         model = Emformer(
             num_features=num_features,
             chunk_length=chunk_length,
@@ -501,7 +537,7 @@ def test_emformer_infer():
             cnn_module_kernel=kernel_size,
             left_context_length=left_context_length,
             right_context_length=right_context_length,
-            max_memory_size=M,
+            max_memory_size=max_memory_size,
         )
         states = None
         conv_caches = None
@@ -523,7 +559,7 @@ def test_emformer_infer():
         assert len(states) == num_encoder_layers
         for state in states:
             assert len(state) == 4
-            assert state[0].shape == (M, B, D)
+            assert state[0].shape == (max_memory_size, B, D)
             assert state[1].shape == (left_context_length // 4, B, D)
             assert state[2].shape == (left_context_length // 4, B, D)
             assert torch.equal(
|