Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-07 08:04:18 +00:00)

Commit 19a8700301 (parent: f94ad976c6)

add pos enc on memory, left context, chunk, and right context (in key)
@@ -450,6 +450,10 @@ class EmformerAttention(nn.Module):
         Embedding dimension.
       nhead (int):
         Number of attention heads in each Emformer layer.
+      chunk_length (int):
+        Length of each input chunk.
+      right_context_length (int):
+        Length of right context.
       dropout (float, optional):
         Dropout probability. (Default: 0.0)
       tanh_on_mem (bool, optional):
@@ -462,6 +466,8 @@ class EmformerAttention(nn.Module):
         self,
         embed_dim: int,
         nhead: int,
+        chunk_length: int,
+        right_context_length: int,
         dropout: float = 0.0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
@@ -479,6 +485,8 @@ class EmformerAttention(nn.Module):
         self.tanh_on_mem = tanh_on_mem
         self.negative_inf = negative_inf
         self.head_dim = embed_dim // nhead
+        self.chunk_length = chunk_length
+        self.right_context_length = right_context_length
         self.dropout = dropout

         self.emb_to_key_value = ScaledLinear(
@@ -572,13 +580,16 @@ class EmformerAttention(nn.Module):
         Args:
           x: Input tensor, of shape (B, nhead, U, PE).
             U is the length of query vector.
-            For training and validation mode, PE = 2 * U - 1;
-            for inference mode, PE = L + 2 * U - 1.
+            For training and validation mode,
+            PE = 2 * U + right_context_length - 1.
+            For inference mode,
+            PE = tot_left_length + 2 * U + right_context_length - 1,
+            where tot_left_length = M * chunk_length.

         Returns:
           A tensor of shape (B, nhead, U, out_len).
-          For non-infer mode, out_len = U;
-          for infer mode, out_len = L + U.
+          For training and validation mode, out_len = U + right_context_length.
+          For inference mode, out_len = tot_left_length + U + right_context_length.  # noqa
         """
         B, nhead, U, PE = x.size()
         B_stride = x.stride(0)
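For intuition, here is a standalone sketch of the rel-shift trick this docstring describes, with toy sizes (all names and numbers below are illustrative assumptions, not part of this diff). Each query row u reads a window of the PE axis starting at offset U - 1 - u, so a single as_strided view realigns relative positions to absolute key positions without copying:

    import torch

    # Toy sizes (assumed): training/validation mode, so PE = 2 * U + rcl - 1.
    B, nhead, U, right_context_length = 1, 1, 4, 2
    PE = 2 * U + right_context_length - 1
    out_len = U + right_context_length

    x = torch.arange(B * nhead * U * PE, dtype=torch.float32).view(B, nhead, U, PE)
    shifted = x.as_strided(
        size=(B, nhead, U, out_len),
        # Shrinking the row stride by one element makes each successive query
        # row start one position earlier along the PE axis.
        stride=(x.stride(0), x.stride(1), x.stride(2) - x.stride(3), x.stride(3)),
        storage_offset=x.stride(3) * (U - 1),
    )
    for u in range(U):
        # Row u is the window x[..., u, U - 1 - u : U - 1 - u + out_len].
        assert torch.equal(
            shifted[0, 0, u], x[0, 0, u, U - 1 - u : U - 1 - u + out_len]
        )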
@@ -592,6 +603,33 @@ class EmformerAttention(nn.Module):
             storage_offset=PE_stride * (U - 1),
         )

+    def _get_right_context_part(
+        self, matrix_bd_utterance: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Args:
+          matrix_bd_utterance:
+            (B * nhead, U, U + right_context_length)
+
+        Returns:
+          A tensor of shape (B * nhead, U, R),
+          where R = num_chunks * right_context_length.
+        """
+        assert self.right_context_length > 0
+        U = matrix_bd_utterance.size(1)
+        num_chunks = math.ceil(U / self.chunk_length)
+        right_context_blocks = []
+        for i in range(num_chunks - 1):
+            start_idx = (i + 1) * self.chunk_length
+            end_idx = start_idx + self.right_context_length
+            right_context_blocks.append(
+                matrix_bd_utterance[:, :, start_idx:end_idx]
+            )
+        right_context_blocks.append(
+            matrix_bd_utterance[:, :, -self.right_context_length :]
+        )
+        return torch.cat(right_context_blocks, dim=2)
+
     def _forward_impl(
         self,
         utterance: torch.Tensor,
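To see which columns _get_right_context_part gathers, a self-contained walk-through with assumed toy sizes (the layout of the last axis, [U utterance frames | trailing right context], follows the docstring above):

    import math

    import torch

    # Assumed toy sizes: 2 chunks of 4 frames, 2 right-context frames per chunk.
    chunk_length, right_context_length = 4, 2
    U = 8
    num_chunks = math.ceil(U / chunk_length)

    # Column c of the last axis holds the positional score for key frame c.
    scores = torch.arange(U + right_context_length, dtype=torch.float32)
    scores = scores.expand(1, U, -1)  # (B * nhead, U, U + right_context_length)

    blocks = []
    for i in range(num_chunks - 1):
        # The right context of chunk i is the first frames of chunk i + 1.
        start = (i + 1) * chunk_length
        blocks.append(scores[:, :, start : start + right_context_length])
    # The last chunk's right context is the trailing right-context columns.
    blocks.append(scores[:, :, -right_context_length:])
    out = torch.cat(blocks, dim=2)

    assert out.shape == (1, U, num_chunks * right_context_length)
    print(out[0, 0])  # tensor([4., 5., 8., 9.])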
@@ -603,7 +641,15 @@ class EmformerAttention(nn.Module):
         pos_emb: torch.Tensor,
         left_context_key: Optional[torch.Tensor] = None,
         left_context_val: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        need_weights=False,
+    ) -> Tuple[
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+    ]:
         """Underlying chunk-wise attention implementation."""
         U, B, _ = utterance.size()
         R = right_context.size(0)
@@ -664,10 +710,12 @@ class EmformerAttention(nn.Module):
         if left_context_key is not None and left_context_val is not None:
             # inference mode
             L = left_context_key.size(0)
-            assert PE == L + 2 * U - 1
+            tot_left_length = M * self.chunk_length if M > 0 else L
+            assert tot_left_length >= L
+            assert PE == tot_left_length + 2 * U + self.right_context_length - 1
         else:
             # training and validation mode
-            assert PE == 2 * U - 1
+            assert PE == 2 * U + self.right_context_length - 1
         pos_emb = (
             self.linear_pos(pos_emb)
             .view(PE, self.nhead, self.head_dim)
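A quick sanity check of the new PE lengths asserted above, using assumed toy values; PE is just (max relative distance) - (min relative distance) + 1:

    # Assumed toy values, not from the diff:
    U, right_context_length, chunk_length = 8, 2, 8

    # Training/validation: i - j ranges over [-(U + right_context_length - 1), U - 1],
    # hence PE = (U - 1) - (-(U + right_context_length - 1)) + 1.
    assert 2 * U + right_context_length - 1 == 17

    # Inference with M = 2 memory vectors: the pooled history spans
    # tot_left_length = M * chunk_length key positions.
    M = 2
    tot_left_length = M * chunk_length
    assert tot_left_length + 2 * U + right_context_length - 1 == 33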
@@ -679,13 +727,49 @@ class EmformerAttention(nn.Module):
         )  # (B, nhead, U, PE)
         # rel-shift operation
         matrix_bd_utterance = self._rel_shift(matrix_bd_utterance)
-        # (B, nhead, U, U) for training and validation mode;
-        # (B, nhead, U, L + U) for inference mode.
+        # (B, nhead, U, U + right_context_length) for training and validation mode;  # noqa
+        # (B, nhead, U, tot_left_length + U + right_context_length) for inference mode.  # noqa
         matrix_bd_utterance = matrix_bd_utterance.contiguous().view(
             B * self.nhead, U, -1
         )
         matrix_bd = torch.zeros_like(matrix_ac)
-        matrix_bd[:, R : R + U, M + R :] = matrix_bd_utterance
+        if left_context_key is not None and left_context_val is not None:
+            # inference mode
+            # key: [memory, right context, left context, utterance]
+            # for memory
+            if M > 0:
+                matrix_bd[:, R : R + U, :M] = torch.nn.functional.avg_pool2d(
+                    matrix_bd_utterance[:, :, :tot_left_length].unsqueeze(1),
+                    kernel_size=(1, self.chunk_length),
+                    stride=(1, self.chunk_length),
+                ).squeeze(1)
+            # for right_context
+            if R > 0:
+                matrix_bd[:, R : R + U, M : M + R] = matrix_bd_utterance[
+                    :, :, tot_left_length + U :
+                ]
+            # for left_context and utterance
+            matrix_bd[:, R : R + U, M + R :] = matrix_bd_utterance[
+                :, :, tot_left_length - L : tot_left_length + U
+            ]
+        else:
+            # training and validation mode
+            # key: [memory, right context, utterance]
+            # for memory
+            if M > 0:
+                matrix_bd[:, R : R + U, :M] = torch.nn.functional.avg_pool2d(
+                    matrix_bd_utterance[:, :, :U].unsqueeze(1),
+                    kernel_size=(1, self.chunk_length),
+                    stride=(1, self.chunk_length),
+                    ceil_mode=True,
+                ).squeeze(1)[:, :, :-1]
+            # for right_context
+            if R > 0:
+                matrix_bd[
+                    :, R : R + U, M : M + R
+                ] = self._get_right_context_part(matrix_bd_utterance)
+            # for utterance
+            matrix_bd[:, R : R + U, M + R :] = matrix_bd_utterance[:, :, :U]

         attention_weights = matrix_ac + matrix_bd

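The memory columns above are filled by average-pooling the frame-level positional scores chunk by chunk, mirroring how each memory vector is itself a mean over one past chunk. A minimal recomputation with assumed shapes:

    import torch
    import torch.nn.functional as F

    # Assumed toy sizes: M = 2 memory vectors summarizing 2 past chunks.
    B_nhead, U, chunk_length, M = 1, 4, 4, 2
    tot_left_length = M * chunk_length

    scores = torch.randn(B_nhead, U, tot_left_length)  # frame-level scores
    pooled = F.avg_pool2d(
        scores.unsqueeze(1),  # (B * nhead, 1, U, tot_left_length)
        kernel_size=(1, chunk_length),
        stride=(1, chunk_length),
    ).squeeze(1)  # (B * nhead, U, M): one pooled score per memory vector

    assert pooled.shape == (B_nhead, U, M)
    # The first memory column is the mean of the first chunk's frame scores.
    assert torch.allclose(pooled[0, 0, 0], scores[0, 0, :chunk_length].mean())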
@@ -717,7 +801,29 @@ class EmformerAttention(nn.Module):
         else:
             output_memory = torch.clamp(output_memory, min=-10, max=10)

-        return output_right_context_utterance, output_memory, key, value
+        if need_weights:
+            # average over attention heads
+            attention_probs = attention_probs.reshape(B, self.nhead, Q, KV)
+            attention_probs = attention_probs.sum(dim=1) / self.nhead
+            probs_memory = attention_probs[:, R : R + U, :M].sum(dim=2)
+            probs_frames = attention_probs[:, R : R + U, M:].sum(dim=2)
+            return (
+                output_right_context_utterance,
+                output_memory,
+                key,
+                value,
+                probs_memory,
+                probs_frames,
+            )

+        return (
+            output_right_context_utterance,
+            output_memory,
+            key,
+            value,
+            None,
+            None,
+        )

     def forward(
         self,
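The need_weights branch summarizes where each utterance query puts its attention mass. A toy recomputation (all sizes assumed, masking omitted) showing that the two summaries partition each query's probability mass:

    import torch

    # Assumed toy sizes: Q queries, KV keys, of which the first M are memory.
    B, nhead, Q, KV, R, U, M = 2, 4, 10, 16, 2, 8, 3
    attention_probs = torch.softmax(torch.randn(B * nhead, Q, KV), dim=-1)

    # Average the per-head distributions, as in the branch above.
    probs = attention_probs.reshape(B, nhead, Q, KV).sum(dim=1) / nhead
    probs_memory = probs[:, R : R + U, :M].sum(dim=2)  # mass on memory slots
    probs_frames = probs[:, R : R + U, M:].sum(dim=2)  # mass on actual frames

    # For every utterance query the two summaries sum to 1.
    assert torch.allclose(probs_memory + probs_frames, torch.ones(B, U))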
@@ -766,10 +872,11 @@ class EmformerAttention(nn.Module):
           s_i (in query) -> l_i, c_i, r_i (in key).

         Relative positional encoding is applied on the part of attention between
-        utterance (in query) and utterance (in key). Actually, it is applied on
-        the part of attention between each chunk (in query) and itself as well
-        as its left context (in key), after applying the mask:
-        c_i -> l_i, c_i.
+        [utterance] (in query) and [memory, right_context, utterance] (in key).
+        Actually, it is applied on the part of attention between each chunk
+        (in query) and itself, its memory vectors, left context, and right
+        context (in key), after applying the mask:
+        c_i (in query) -> l_i, c_i, r_i, m_i (in key).

         Args:
           utterance (torch.Tensor):
@@ -791,18 +898,23 @@ class EmformerAttention(nn.Module):
             attention, with shape (Q, KV).
           pos_emb (torch.Tensor):
             Position encoding embedding, with shape (PE, D).
-            where PE = 2 * U - 1.
+            where PE = 2 * U + right_context_length - 1.

         Returns:
-          A tuple containing 2 tensors:
+          A tuple containing 4 tensors:
             - output of right context and utterance, with shape (R + U, B, D).
             - memory output, with shape (M, B, D), where M = S - 1 or M = 0.
+            - summary of attention weights on memory, with shape (B, U).
+            - summary of attention weights on left context, utterance, and
+              right context, with shape (B, U).
         """
         (
             output_right_context_utterance,
             output_memory,
             _,
             _,
+            probs_memory,
+            probs_frames,
         ) = self._forward_impl(
             utterance,
             lengths,
@@ -811,8 +923,14 @@ class EmformerAttention(nn.Module):
             memory,
             attention_mask,
             pos_emb,
+            need_weights=True,
         )
-        return output_right_context_utterance, output_memory[:-1]
+        return (
+            output_right_context_utterance,
+            output_memory[:-1],
+            probs_memory,
+            probs_frames,
+        )

     def infer(
         self,
@@ -849,9 +967,9 @@ class EmformerAttention(nn.Module):
           left context, chunk, right context, memory vectors (in key);
           summary (in query) -> left context, chunk, right context (in key).

-        Relative positional encoding is applied on the part of attention between
-        chunk (in query) and chunk itself as well as its left context (in key):
-        chunk (in query) -> left context, chunk (in key).
+        Relative positional encoding is applied on the part of attention:
+        chunk (in query) ->
+        left context, chunk, right context, memory vectors (in key);

         Args:
           utterance (torch.Tensor):
@@ -874,7 +992,7 @@ class EmformerAttention(nn.Module):
             with shape (L, B, D), where L <= left_context_length.
           pos_emb (torch.Tensor):
             Position encoding embedding, with shape (PE, D),
-            where PE = L + 2 * U - 1.
+            where PE = (M * chunk_length if M > 0 else L) + 2 * U + right_context_length - 1.  # noqa

         Returns:
           A tuple containing 4 tensors:
@@ -905,6 +1023,8 @@ class EmformerAttention(nn.Module):
             output_memory,
             key,
             value,
+            _,
+            _,
         ) = self._forward_impl(
             utterance,
             lengths,
@@ -974,6 +1094,8 @@ class EmformerEncoderLayer(nn.Module):
         self.attention = EmformerAttention(
             embed_dim=d_model,
             nhead=nhead,
+            chunk_length=chunk_length,
+            right_context_length=right_context_length,
             dropout=dropout,
             tanh_on_mem=tanh_on_mem,
             negative_inf=negative_inf,
@@ -1018,6 +1140,7 @@ class EmformerEncoderLayer(nn.Module):
         self.layer_dropout = layer_dropout
         self.left_context_length = left_context_length
+        self.chunk_length = chunk_length
         self.right_context_length = right_context_length
         self.max_memory_size = max_memory_size
         self.d_model = d_model
         self.use_memory = max_memory_size > 0
@@ -1140,7 +1263,7 @@ class EmformerEncoderLayer(nn.Module):
             summary = torch.empty(0).to(
                 dtype=utterance.dtype, device=utterance.device
             )
-        output_right_context_utterance, output_memory = self.attention(
+        output_right_context_utterance, output_memory, _, _ = self.attention(
             utterance=utterance,
             lengths=lengths,
             right_context=right_context,
@@ -1190,14 +1313,26 @@ class EmformerEncoderLayer(nn.Module):
             summary = torch.empty(0).to(
                 dtype=utterance.dtype, device=utterance.device
             )
-        # pos_emb is of shape [PE, D], where PE = L + 2 * U - 1,
-        # for query of [utterance] (i), key-value [left_context, utterance] (j),
-        # the max relative distance i - j is L + U - 1
-        # the min relative distance i - j is -(U - 1)
-        L = left_context_key.size(0)  # L <= left_context_length
         U = utterance.size(0)
-        PE = L + 2 * U - 1
-        tot_PE = self.left_context_length + 2 * U - 1
+        # pos_emb is of shape [PE, D], where PE = M * chunk_length + 2 * U + right_context_length - 1,  # noqa
+        # for query of [utterance] (i), key-value [memory vectors, left context, utterance, right context] (j)  # noqa
+        # the max relative distance i - j is M * chunk_length + U - 1
+        # the min relative distance i - j is -(U + right_context_length - 1)
+        M = pre_memory.size(0)  # M <= max_memory_size
+        if self.max_memory_size > 0:
+            PE = M * self.chunk_length + 2 * U + self.right_context_length - 1
+            tot_PE = (
+                self.max_memory_size * self.chunk_length
+                + 2 * U
+                + self.right_context_length
+                - 1
+            )
+        else:
+            L = left_context_key.size(0)
+            PE = L + 2 * U + self.right_context_length - 1
+            tot_PE = (
+                self.left_context_length + 2 * U + self.right_context_length - 1
+            )
         assert pos_emb.size(0) == tot_PE
         pos_emb = pos_emb[tot_PE - PE :]
         (
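The slicing pos_emb[tot_PE - PE :] above works because the table is laid out from the most distant past position to the most distant future one; a shorter actual history only drops rows at the far-past end. With assumed toy values:

    import torch

    # Assumed toy values, not from the diff:
    D = 8
    max_memory_size, chunk_length, right_context_length, U = 4, 4, 2, 4

    tot_PE = max_memory_size * chunk_length + 2 * U + right_context_length - 1
    M = 2  # only two chunks processed so far, so a shorter history
    PE = M * chunk_length + 2 * U + right_context_length - 1

    pos_emb = torch.randn(tot_PE, D)  # precomputed for the maximal history
    pos_emb_cur = pos_emb[tot_PE - PE :]  # keep the smallest relative distances
    assert pos_emb_cur.shape == (PE, D)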
@@ -1661,7 +1796,9 @@ class EmformerEncoder(nn.Module):
           right_context at the end.
         """
         U = x.size(0) - self.right_context_length
-        x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
+        x, pos_emb = self.encoder_pos(
+            x, pos_len=U, neg_len=U + self.right_context_length
+        )

         right_context = self._gen_right_context(x)
         utterance = x[:U]
@@ -1734,8 +1871,12 @@ class EmformerEncoder(nn.Module):
                 f"for dimension 1 of x, but got {x.size(1)}."
             )

-        pos_len = self.chunk_length + self.left_context_length
-        neg_len = self.chunk_length
+        pos_len = (
+            self.max_memory_size * self.chunk_length + self.chunk_length
+            if self.max_memory_size > 0
+            else self.left_context_length + self.chunk_length
+        )
+        neg_len = self.chunk_length + self.right_context_length
         x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)

         right_context_start_idx = x.size(0) - self.right_context_length
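A quick check, with assumed values, that pos_len/neg_len above produce exactly the tot_PE rows that EmformerEncoderLayer.infer expects (pos_len covers distances into the past, neg_len distances into the future, sharing the zero offset):

    # Assumed toy values; in infer mode U equals chunk_length.
    chunk_length, left_context_length, right_context_length = 8, 32, 2
    max_memory_size = 4

    pos_len = (
        max_memory_size * chunk_length + chunk_length
        if max_memory_size > 0
        else left_context_length + chunk_length
    )
    neg_len = chunk_length + right_context_length
    tot_PE = (
        max_memory_size * chunk_length
        + 2 * chunk_length
        + right_context_length
        - 1
    )
    assert pos_len + neg_len - 1 == tot_PE  # both equal 49 here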
@@ -1807,6 +1948,13 @@ class Emformer(EncoderInterface):
             raise NotImplementedError(
                 "right_context_length must be 0 or a mutiple of 4."
             )
+        if (
+            max_memory_size > 0
+            and max_memory_size * chunk_length < left_context_length
+        ):
+            raise NotImplementedError(
+                "max_memory_size * chunk_length can not be less than left_context_length"  # noqa
+            )

         # self.encoder_embed converts the input of shape (N, T, num_features)
         # to the shape (N, T//subsampling_factor, d_model).
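The new guard mirrors the assert tot_left_length >= L in _forward_impl: with memory enabled, the pooled memory span M * chunk_length must be able to cover any cached left context. A hypothetical config it now rejects at construction time:

    # Hypothetical values (not from the diff): pooled span 2 * 8 = 16 frames,
    # but up to 32 left-context frames may be cached, so inference could hit
    # `assert tot_left_length >= L`; the constructor now fails fast instead.
    max_memory_size, chunk_length, left_context_length = 2, 8, 32
    assert max_memory_size * chunk_length < left_context_length  # rejected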
@@ -22,7 +22,12 @@ def test_emformer_attention_forward():
     num_chunks = 3
     U = num_chunks * chunk_length
     R = num_chunks * right_context_length
-    attention = EmformerAttention(embed_dim=D, nhead=8)
+    attention = EmformerAttention(
+        embed_dim=D,
+        nhead=8,
+        chunk_length=chunk_length,
+        right_context_length=right_context_length,
+    )

     for use_memory in [True, False]:
         if use_memory:
@@ -39,10 +44,15 @@ def test_emformer_attention_forward():
         summary = torch.randn(S, B, D)
         memory = torch.randn(M, B, D)
         attention_mask = torch.rand(Q, KV) >= 0.5
-        PE = 2 * U - 1
+        PE = 2 * U + right_context_length - 1
         pos_emb = torch.randn(PE, D)

-        output_right_context_utterance, output_memory = attention(
+        (
+            output_right_context_utterance,
+            output_memory,
+            probs_memory,
+            probs_frames,
+        ) = attention(
             utterance,
             lengths,
             right_context,
@@ -53,16 +63,26 @@ def test_emformer_attention_forward():
         )
         assert output_right_context_utterance.shape == (R + U, B, D)
         assert output_memory.shape == (M, B, D)
+        assert probs_memory.shape == (B, U)
+        assert probs_frames.shape == (B, U)


 def test_emformer_attention_infer():
     from emformer import EmformerAttention

     B, D = 2, 256
-    U = 4
-    R = 2
+    chunk_length = 4
+    right_context_length = 2
+    num_chunks = 1
+    U = chunk_length * num_chunks
+    R = right_context_length * num_chunks
     L = 3
-    attention = EmformerAttention(embed_dim=D, nhead=8)
+    attention = EmformerAttention(
+        embed_dim=D,
+        nhead=8,
+        chunk_length=chunk_length,
+        right_context_length=right_context_length,
+    )

     for use_memory in [True, False]:
         if use_memory:
@@ -78,7 +98,12 @@ def test_emformer_attention_infer():
         memory = torch.randn(M, B, D)
         left_context_key = torch.randn(L, B, D)
         left_context_val = torch.randn(L, B, D)
-        PE = L + 2 * U - 1
+        PE = (
+            2 * U
+            + right_context_length
+            - 1
+            + (M * chunk_length if M > 0 else L)
+        )
         pos_emb = torch.randn(PE, D)

         (
@@ -197,7 +222,7 @@ def test_emformer_encoder_layer_forward():
         right_context = torch.randn(R, B, D)
         memory = torch.randn(M, B, D)
         attention_mask = torch.rand(Q, KV) >= 0.5
-        PE = 2 * U - 1
+        PE = 2 * U + right_context_length - 1
         pos_emb = torch.randn(PE, D)

         output_utterance, output_right_context, output_memory = layer(
@@ -227,8 +252,10 @@ def test_emformer_encoder_layer_infer():

     for use_memory in [True, False]:
         if use_memory:
-            M = 3
+            max_memory_size = 3
+            M = 1
         else:
+            max_memory_size = 0
             M = 0

         layer = EmformerEncoderLayer(
@@ -239,7 +266,7 @@ def test_emformer_encoder_layer_infer():
             cnn_module_kernel=kernel_size,
             left_context_length=left_context_length,
             right_context_length=right_context_length,
-            max_memory_size=M,
+            max_memory_size=max_memory_size,
         )

         utterance = torch.randn(U, B, D)
@@ -248,7 +275,16 @@ def test_emformer_encoder_layer_infer():
         right_context = torch.randn(R, B, D)
         memory = torch.randn(M, B, D)
         state = None
-        PE = left_context_length + 2 * U - 1
+        PE = (
+            2 * U
+            + right_context_length
+            - 1
+            + (
+                max_memory_size * chunk_length
+                if max_memory_size > 0
+                else left_context_length
+            )
+        )
         pos_emb = torch.randn(PE, D)
         conv_cache = None
         (
@@ -273,7 +309,7 @@ def test_emformer_encoder_layer_infer():
         else:
             assert output_memory.shape == (0, B, D)
         assert len(output_state) == 4
-        assert output_state[0].shape == (M, B, D)
+        assert output_state[0].shape == (max_memory_size, B, D)
         assert output_state[1].shape == (left_context_length, B, D)
         assert output_state[2].shape == (left_context_length, B, D)
         assert output_state[3].shape == (1, B)
@@ -334,9 +370,9 @@ def test_emformer_encoder_infer():

     for use_memory in [True, False]:
         if use_memory:
-            M = 3
+            max_memory_size = 3
         else:
-            M = 0
+            max_memory_size = 0

         encoder = EmformerEncoder(
             chunk_length=chunk_length,
@@ -346,7 +382,7 @@ def test_emformer_encoder_infer():
             cnn_module_kernel=kernel_size,
             left_context_length=left_context_length,
             right_context_length=right_context_length,
-            max_memory_size=M,
+            max_memory_size=max_memory_size,
         )

         states = None
@@ -368,7 +404,7 @@ def test_emformer_encoder_infer():
         assert len(states) == num_encoder_layers
         for state in states:
             assert len(state) == 4
-            assert state[0].shape == (M, B, D)
+            assert state[0].shape == (max_memory_size, B, D)
             assert state[1].shape == (left_context_length, B, D)
             assert state[2].shape == (left_context_length, B, D)
         assert torch.equal(
@@ -391,7 +427,7 @@ def test_emformer_encoder_forward_infer_consistency():
     kernel_size = 31
     memory_sizes = [0, 3]

-    for M in memory_sizes:
+    for max_memory_size in memory_sizes:
         encoder = EmformerEncoder(
             chunk_length=chunk_length,
             d_model=D,
@@ -400,7 +436,7 @@ def test_emformer_encoder_forward_infer_consistency():
             cnn_module_kernel=kernel_size,
             left_context_length=left_context_length,
             right_context_length=right_context_length,
-            max_memory_size=M,
+            max_memory_size=max_memory_size,
         )
         encoder.eval()

@@ -449,9 +485,9 @@ def test_emformer_forward():

     for use_memory in [True, False]:
         if use_memory:
-            M = 3
+            max_memory_size = 3
         else:
-            M = 0
+            max_memory_size = 0
         model = Emformer(
             num_features=num_features,
             chunk_length=chunk_length,
@@ -460,7 +496,7 @@ def test_emformer_forward():
             cnn_module_kernel=kernel_size,
             left_context_length=left_context_length,
             right_context_length=right_context_length,
-            max_memory_size=M,
+            max_memory_size=max_memory_size,
         )
         x = torch.randn(B, U + right_context_length + 3, num_features)
         x_lens = torch.randint(1, U + right_context_length + 3 + 1, (B,))
@@ -481,7 +517,7 @@ def test_emformer_infer():
     num_features = 80
     chunk_length = 8
     U = chunk_length
-    left_context_length, right_context_length = 128, 4
+    left_context_length, right_context_length = 32, 4
     B, D = 2, 256
     num_chunks = 3
     num_encoder_layers = 2
@@ -489,9 +525,9 @@ def test_emformer_infer():

     for use_memory in [True, False]:
         if use_memory:
-            M = 3
+            max_memory_size = 32
         else:
-            M = 0
+            max_memory_size = 0
         model = Emformer(
             num_features=num_features,
             chunk_length=chunk_length,
@@ -501,7 +537,7 @@ def test_emformer_infer():
             cnn_module_kernel=kernel_size,
             left_context_length=left_context_length,
             right_context_length=right_context_length,
-            max_memory_size=M,
+            max_memory_size=max_memory_size,
         )
         states = None
         conv_caches = None
@@ -523,7 +559,7 @@ def test_emformer_infer():
         assert len(states) == num_encoder_layers
         for state in states:
             assert len(state) == 4
-            assert state[0].shape == (M, B, D)
+            assert state[0].shape == (max_memory_size, B, D)
             assert state[1].shape == (left_context_length // 4, B, D)
             assert state[2].shape == (left_context_length // 4, B, D)
         assert torch.equal(