add pos enc on memory, left context, chunk, and right context (in key)

Author: yaozengwei
Date: 2022-05-20 11:13:29 +08:00
Parent: f94ad976c6
Commit: 19a8700301
2 changed files with 242 additions and 58 deletions

File 1 of 2: the Emformer model (EmformerAttention, EmformerEncoderLayer, EmformerEncoder, Emformer)

@@ -450,6 +450,10 @@ class EmformerAttention(nn.Module):
         Embedding dimension.
       nhead (int):
         Number of attention heads in each Emformer layer.
+      chunk_length (int):
+        Length of each input chunk.
+      right_context_length (int):
+        Length of right context.
       dropout (float, optional):
         Dropout probability. (Default: 0.0)
       tanh_on_mem (bool, optional):
@@ -462,6 +466,8 @@ class EmformerAttention(nn.Module):
         self,
         embed_dim: int,
         nhead: int,
+        chunk_length: int,
+        right_context_length: int,
         dropout: float = 0.0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
@@ -479,6 +485,8 @@ class EmformerAttention(nn.Module):
         self.tanh_on_mem = tanh_on_mem
         self.negative_inf = negative_inf
         self.head_dim = embed_dim // nhead
+        self.chunk_length = chunk_length
+        self.right_context_length = right_context_length
         self.dropout = dropout

         self.emb_to_key_value = ScaledLinear(
@@ -572,13 +580,16 @@ class EmformerAttention(nn.Module):
         Args:
           x: Input tensor, of shape (B, nhead, U, PE).
             U is the length of query vector.
-            For training and validation mode, PE = 2 * U - 1;
-            for inference mode, PE = L + 2 * U - 1.
+            For training and validation mode,
+            PE = 2 * U + right_context_length - 1.
+            For inference mode,
+            PE = tot_left_length + 2 * U + right_context_length - 1,
+            where tot_left_length = M * chunk_length.

         Returns:
           A tensor of shape (B, nhead, U, out_len).
-          For non-infer mode, out_len = U;
-          for infer mode, out_len = L + U.
+          For training and validation mode, out_len = U + right_context_length.
+          For inference mode, out_len = tot_left_length + U + right_context_length.  # noqa
         """
         B, nhead, U, PE = x.size()
         B_stride = x.stride(0)
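Note on the rel-shift: this hunk only shows the storage_offset line of the as_strided call, but the documented shapes fix what the strided view must do, since out_len = PE - U + 1 (that is, U + right_context_length in training mode). Below is a minimal standalone sketch of such a rel-shift; the size/stride arguments are reconstructed from the visible storage_offset and the stated shapes, not quoted from the file.

import torch

B, nhead, U, right_context_length = 1, 1, 4, 2
PE = 2 * U + right_context_length - 1  # training/validation mode
out_len = PE - U + 1                   # = U + right_context_length

x = torch.arange(B * nhead * U * PE, dtype=torch.float32).reshape(B, nhead, U, PE)
b_s, h_s, u_s, pe_s = x.stride()

# Row u of the output reads x[..., u, U - 1 - u : U - 1 - u + out_len], i.e.
# each query row starts one position earlier than the previous one, realized
# with a single as_strided view instead of per-row slicing.
shifted = x.as_strided(
    size=(B, nhead, U, out_len),
    stride=(b_s, h_s, u_s - pe_s, pe_s),
    storage_offset=pe_s * (U - 1),
)
assert shifted.shape == (B, nhead, U, U + right_context_length)
assert torch.equal(shifted[0, 0, 1], x[0, 0, 1, U - 2 : U - 2 + out_len])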
@@ -592,6 +603,33 @@ class EmformerAttention(nn.Module):
             storage_offset=PE_stride * (U - 1),
         )

+    def _get_right_context_part(
+        self, matrix_bd_utterance: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Args:
+          matrix_bd_utterance:
+            (B * nhead, U, U + right_context_length)
+
+        Returns:
+          A tensor of shape (B * nhead, U, R),
+          where R = num_chunks * right_context_length.
+        """
+        assert self.right_context_length > 0
+        U = matrix_bd_utterance.size(1)
+        num_chunks = math.ceil(U / self.chunk_length)
+        right_context_blocks = []
+        for i in range(num_chunks - 1):
+            start_idx = (i + 1) * self.chunk_length
+            end_idx = start_idx + self.right_context_length
+            right_context_blocks.append(
+                matrix_bd_utterance[:, :, start_idx:end_idx]
+            )
+        right_context_blocks.append(
+            matrix_bd_utterance[:, :, -self.right_context_length :]
+        )
+        return torch.cat(right_context_blocks, dim=2)
+
     def _forward_impl(
         self,
         utterance: torch.Tensor,
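To make the new helper's output width concrete, here is the same slicing run standalone on toy shapes (this duplicates the loop above rather than calling the class, so the numbers are easy to follow):

import math
import torch

B_nhead, chunk_length, right_context_length, num_chunks = 2, 4, 2, 3
U = num_chunks * chunk_length
matrix_bd_utterance = torch.randn(B_nhead, U, U + right_context_length)

blocks = []
# one right-context block per chunk except the last ...
for i in range(math.ceil(U / chunk_length) - 1):
    start = (i + 1) * chunk_length
    blocks.append(matrix_bd_utterance[:, :, start : start + right_context_length])
# ... and the trailing columns for the last chunk
blocks.append(matrix_bd_utterance[:, :, -right_context_length:])
out = torch.cat(blocks, dim=2)
assert out.shape == (B_nhead, U, num_chunks * right_context_length)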
@@ -603,7 +641,15 @@ class EmformerAttention(nn.Module):
         pos_emb: torch.Tensor,
         left_context_key: Optional[torch.Tensor] = None,
         left_context_val: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        need_weights=False,
+    ) -> Tuple[
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+    ]:
         """Underlying chunk-wise attention implementation."""
         U, B, _ = utterance.size()
         R = right_context.size(0)
@@ -664,10 +710,12 @@ class EmformerAttention(nn.Module):
         if left_context_key is not None and left_context_val is not None:
             # inference mode
             L = left_context_key.size(0)
-            assert PE == L + 2 * U - 1
+            tot_left_length = M * self.chunk_length if M > 0 else L
+            assert tot_left_length >= L
+            assert PE == tot_left_length + 2 * U + self.right_context_length - 1
         else:
             # training and validation mode
-            assert PE == 2 * U - 1
+            assert PE == 2 * U + self.right_context_length - 1
         pos_emb = (
             self.linear_pos(pos_emb)
             .view(PE, self.nhead, self.head_dim)
@@ -679,13 +727,49 @@ class EmformerAttention(nn.Module):
         )  # (B, nhead, U, PE)

         # rel-shift operation
         matrix_bd_utterance = self._rel_shift(matrix_bd_utterance)
-        # (B, nhead, U, U) for training and validation mode;
-        # (B, nhead, U, L + U) for inference mode.
+        # (B, nhead, U, U + right_context_length) for training and validation mode;  # noqa
+        # (B, nhead, U, tot_left_length + U + right_context_length) for inference mode.  # noqa
         matrix_bd_utterance = matrix_bd_utterance.contiguous().view(
             B * self.nhead, U, -1
         )

         matrix_bd = torch.zeros_like(matrix_ac)
-        matrix_bd[:, R : R + U, M + R :] = matrix_bd_utterance
+        if left_context_key is not None and left_context_val is not None:
+            # inference mode
+            # key: [memory, right context, left context, utterance]
+            # for memory
+            if M > 0:
+                matrix_bd[:, R : R + U, :M] = torch.nn.functional.avg_pool2d(
+                    matrix_bd_utterance[:, :, :tot_left_length].unsqueeze(1),
+                    kernel_size=(1, self.chunk_length),
+                    stride=(1, self.chunk_length),
+                ).squeeze(1)
+            # for right_context
+            if R > 0:
+                matrix_bd[:, R : R + U, M : M + R] = matrix_bd_utterance[
+                    :, :, tot_left_length + U :
+                ]
+            # for left_context and utterance
+            matrix_bd[:, R : R + U, M + R :] = matrix_bd_utterance[
+                :, :, tot_left_length - L : tot_left_length + U
+            ]
+        else:
+            # training and validation mode
+            # key: [memory, right context, utterance]
+            # for memory
+            if M > 0:
+                matrix_bd[:, R : R + U, :M] = torch.nn.functional.avg_pool2d(
+                    matrix_bd_utterance[:, :, :U].unsqueeze(1),
+                    kernel_size=(1, self.chunk_length),
+                    stride=(1, self.chunk_length),
+                    ceil_mode=True,
+                ).squeeze(1)[:, :, :-1]
+            # for right_context
+            if R > 0:
+                matrix_bd[
+                    :, R : R + U, M : M + R
+                ] = self._get_right_context_part(matrix_bd_utterance)
+            # for utterance
+            matrix_bd[:, R : R + U, M + R :] = matrix_bd_utterance[:, :, :U]

         attention_weights = matrix_ac + matrix_bd
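The two avg_pool2d calls above give each memory column the average of the per-frame positional scores of the chunk_length frames that the memory vector summarizes. A quick standalone shape check with toy sizes:

import torch
import torch.nn.functional as F

B_nhead, chunk_length, num_chunks = 2, 4, 3
U = num_chunks * chunk_length
M = num_chunks - 1  # memory size used in training (S - 1)

# training/validation branch: pool the U utterance columns chunk by chunk;
# ceil_mode keeps a partial final chunk, and the last column is dropped
# because the most recent chunk contributes no memory vector
x = torch.randn(B_nhead, U, U)
pooled = F.avg_pool2d(
    x.unsqueeze(1),
    kernel_size=(1, chunk_length),
    stride=(1, chunk_length),
    ceil_mode=True,
).squeeze(1)[:, :, :-1]
assert pooled.shape == (B_nhead, U, M)

# inference branch: tot_left_length = M * chunk_length, so plain pooling
# yields exactly M memory columns
tot_left_length = M * chunk_length
y = torch.randn(B_nhead, U, tot_left_length)
pooled_infer = F.avg_pool2d(
    y.unsqueeze(1), kernel_size=(1, chunk_length), stride=(1, chunk_length)
).squeeze(1)
assert pooled_infer.shape == (B_nhead, U, M)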
@@ -717,7 +801,29 @@ class EmformerAttention(nn.Module):
         else:
             output_memory = torch.clamp(output_memory, min=-10, max=10)

-        return output_right_context_utterance, output_memory, key, value
+        if need_weights:
+            # average over attention heads
+            attention_probs = attention_probs.reshape(B, self.nhead, Q, KV)
+            attention_probs = attention_probs.sum(dim=1) / self.nhead
+            probs_memory = attention_probs[:, R : R + U, :M].sum(dim=2)
+            probs_frames = attention_probs[:, R : R + U, M:].sum(dim=2)
+            return (
+                output_right_context_utterance,
+                output_memory,
+                key,
+                value,
+                probs_memory,
+                probs_frames,
+            )
+
+        return (
+            output_right_context_utterance,
+            output_memory,
+            key,
+            value,
+            None,
+            None,
+        )

     def forward(
         self,
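The need_weights branch returns two (B, U) summaries: per utterance frame, the attention mass put on memory slots versus on ordinary frames. A toy illustration of exactly that reduction, using random softmax outputs in place of the real attention_probs (the Q/KV sizes here are arbitrary demo values):

import torch

B, nhead, R, U, M = 2, 4, 2, 8, 3
Q, KV = R + U + 2, M + R + U  # arbitrary shapes for the demo
attention_probs = torch.softmax(torch.randn(B * nhead, Q, KV), dim=-1)

# average over heads, then split each utterance query's mass between
# the first M key positions (memory) and everything else (frames)
probs = attention_probs.reshape(B, nhead, Q, KV).sum(dim=1) / nhead
probs_memory = probs[:, R : R + U, :M].sum(dim=2)
probs_frames = probs[:, R : R + U, M:].sum(dim=2)

assert probs_memory.shape == (B, U) and probs_frames.shape == (B, U)
# each attention row sums to one, so the two summaries are complementary
assert torch.allclose(probs_memory + probs_frames, torch.ones(B, U))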
@@ -766,10 +872,11 @@ class EmformerAttention(nn.Module):
           s_i (in query) -> l_i, c_i, r_i (in key).

         Relative positional encoding is applied on the part of attention between
-        utterance (in query) and utterance (in key). Actually, it is applied on
-        the part of attention between each chunk (in query) and itself as well
-        as its left context (in key), after applying the mask:
-          c_i -> l_i, c_i.
+        [utterance] (in query) and [memory, right_context, utterance] (in key).
+        Actually, it is applied on the part of attention between each chunk
+        (in query) and itself, its memory vectors, left context, and right
+        context (in key), after applying the mask:
+          c_i (in query) -> l_i, c_i, r_i, m_i (in key).

         Args:
           utterance (torch.Tensor):
@@ -791,18 +898,23 @@ class EmformerAttention(nn.Module):
             attention, with shape (Q, KV).
           pos_emb (torch.Tensor):
             Position encoding embedding, with shape (PE, D).
-            where PE = 2 * U - 1.
+            where PE = 2 * U + right_context_length - 1.

         Returns:
-          A tuple containing 2 tensors:
+          A tuple containing 4 tensors:
           - output of right context and utterance, with shape (R + U, B, D).
           - memory output, with shape (M, B, D), where M = S - 1 or M = 0.
+          - summary of attention weights on memory, with shape (B, U).
+          - summary of attention weights on left context, utterance, and
+            right context, with shape (B, U).
         """
         (
             output_right_context_utterance,
             output_memory,
             _,
             _,
+            probs_memory,
+            probs_frames,
         ) = self._forward_impl(
             utterance,
             lengths,
@@ -811,8 +923,14 @@ class EmformerAttention(nn.Module):
             memory,
             attention_mask,
             pos_emb,
+            need_weights=True,
+        )
+        return (
+            output_right_context_utterance,
+            output_memory[:-1],
+            probs_memory,
+            probs_frames,
         )
-        return output_right_context_utterance, output_memory[:-1]

     def infer(
         self,
@@ -849,9 +967,9 @@ class EmformerAttention(nn.Module):
             left context, chunk, right context, memory vectors (in key);
           summary (in query) -> left context, chunk, right context (in key).

-        Relative positional encoding is applied on the part of attention between
-        chunk (in query) and chunk itself as well as its left context (in key):
-          chunk (in query) -> left context, chunk (in key).
+        Relative positional encoding is applied on the part of attention:
+          chunk (in query) ->
+            left context, chunk, right context, memory vectors (in key).

         Args:
           utterance (torch.Tensor):
@@ -874,7 +992,7 @@ class EmformerAttention(nn.Module):
             with shape (L, B, D), where L <= left_context_length.
           pos_emb (torch.Tensor):
             Position encoding embedding, with shape (PE, D),
-            where PE = L + 2 * U - 1.
+            where PE = tot_left_length + 2 * U + right_context_length - 1,
+            and tot_left_length = M * chunk_length if M > 0 else L.

         Returns:
           A tuple containing 4 tensors:
@@ -905,6 +1023,8 @@ class EmformerAttention(nn.Module):
             output_memory,
             key,
             value,
+            _,
+            _,
         ) = self._forward_impl(
             utterance,
             lengths,
@@ -974,6 +1094,8 @@ class EmformerEncoderLayer(nn.Module):
         self.attention = EmformerAttention(
             embed_dim=d_model,
             nhead=nhead,
+            chunk_length=chunk_length,
+            right_context_length=right_context_length,
             dropout=dropout,
             tanh_on_mem=tanh_on_mem,
             negative_inf=negative_inf,
@@ -1018,6 +1140,7 @@ class EmformerEncoderLayer(nn.Module):
         self.layer_dropout = layer_dropout
         self.left_context_length = left_context_length
         self.chunk_length = chunk_length
+        self.right_context_length = right_context_length
         self.max_memory_size = max_memory_size
         self.d_model = d_model
         self.use_memory = max_memory_size > 0
@@ -1140,7 +1263,7 @@ class EmformerEncoderLayer(nn.Module):
             summary = torch.empty(0).to(
                 dtype=utterance.dtype, device=utterance.device
             )
-        output_right_context_utterance, output_memory = self.attention(
+        output_right_context_utterance, output_memory, _, _ = self.attention(
             utterance=utterance,
             lengths=lengths,
             right_context=right_context,
@@ -1190,14 +1313,26 @@ class EmformerEncoderLayer(nn.Module):
             summary = torch.empty(0).to(
                 dtype=utterance.dtype, device=utterance.device
             )
-        # pos_emb is of shape [PE, D], where PE = L + 2 * U - 1,
-        # for query of [utterance] (i), key-value [left_context, utterance] (j),
-        # the max relative distance i - j is L + U - 1
-        # the min relative distance i - j is -(U - 1)
-        L = left_context_key.size(0)  # L <= left_context_length
         U = utterance.size(0)
-        PE = L + 2 * U - 1
-        tot_PE = self.left_context_length + 2 * U - 1
+        # pos_emb is of shape [PE, D], where
+        # PE = M * chunk_length + 2 * U + right_context_length - 1,
+        # for query of [utterance] (i), key-value [memory vectors, left context, utterance, right context] (j)  # noqa
+        # the max relative distance i - j is M * chunk_length + U - 1
+        # the min relative distance i - j is -(U + right_context_length - 1)
+        M = pre_memory.size(0)  # M <= max_memory_size
+        if self.max_memory_size > 0:
+            PE = M * self.chunk_length + 2 * U + self.right_context_length - 1
+            tot_PE = (
+                self.max_memory_size * self.chunk_length
+                + 2 * U
+                + self.right_context_length
+                - 1
+            )
+        else:
+            L = left_context_key.size(0)
+            PE = L + 2 * U + self.right_context_length - 1
+            tot_PE = (
+                self.left_context_length + 2 * U + self.right_context_length - 1
+            )
         assert pos_emb.size(0) == tot_PE
         pos_emb = pos_emb[tot_PE - PE :]
         (
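A worked example of the PE / tot_PE bookkeeping above, with made-up sizes (not taken from any recipe): chunk_length = U = 8, right_context_length = 4, max_memory_size = 4, and M = 2 memory vectors cached so far.

chunk_length, U, right_context_length = 8, 8, 4
max_memory_size, M = 4, 2

PE = M * chunk_length + 2 * U + right_context_length - 1                     # 35
tot_PE = max_memory_size * chunk_length + 2 * U + right_context_length - 1  # 51
assert PE == 35 and tot_PE == 51
# pos_emb holds tot_PE rows for the worst case (a full memory bank);
# pos_emb[tot_PE - PE:] keeps only the 35 rows needed for this chunk.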
@@ -1661,7 +1796,9 @@ class EmformerEncoder(nn.Module):
             right_context at the end.
         """
         U = x.size(0) - self.right_context_length
-        x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
+        x, pos_emb = self.encoder_pos(
+            x, pos_len=U, neg_len=U + self.right_context_length
+        )
         right_context = self._gen_right_context(x)
         utterance = x[:U]
@@ -1734,8 +1871,12 @@ class EmformerEncoder(nn.Module):
                 f"for dimension 1 of x, but got {x.size(1)}."
             )

-        pos_len = self.chunk_length + self.left_context_length
-        neg_len = self.chunk_length
+        pos_len = (
+            self.max_memory_size * self.chunk_length + self.chunk_length
+            if self.max_memory_size > 0
+            else self.left_context_length + self.chunk_length
+        )
+        neg_len = self.chunk_length + self.right_context_length
         x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)

         right_context_start_idx = x.size(0) - self.right_context_length
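The streaming pos_len/neg_len above are chosen so that the positional embedding has exactly the tot_PE rows expected by EmformerEncoderLayer.infer. Assuming encoder_pos returns pos_len + neg_len - 1 embeddings (which the non-streaming call above implies: pos_len = U and neg_len = U + right_context_length give the asserted 2 * U + right_context_length - 1), the arithmetic checks out for both branches; U equals chunk_length for a single streamed chunk:

chunk_length, left_context_length, right_context_length, max_memory_size = 8, 32, 4, 32

for mem in (max_memory_size, 0):
    pos_len = (
        mem * chunk_length + chunk_length
        if mem > 0
        else left_context_length + chunk_length
    )
    neg_len = chunk_length + right_context_length
    tot_PE = (
        mem * chunk_length + 2 * chunk_length + right_context_length - 1
        if mem > 0
        else left_context_length + 2 * chunk_length + right_context_length - 1
    )
    assert pos_len + neg_len - 1 == tot_PE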
@@ -1807,6 +1948,13 @@ class Emformer(EncoderInterface):
             raise NotImplementedError(
                 "right_context_length must be 0 or a mutiple of 4."
             )
+        if (
+            max_memory_size > 0
+            and max_memory_size * chunk_length < left_context_length
+        ):
+            raise NotImplementedError(
+                "max_memory_size * chunk_length can not be less than left_context_length"  # noqa
+            )

         # self.encoder_embed converts the input of shape (N, T, num_features)
         # to the shape (N, T//subsampling_factor, d_model).
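The new check mirrors the assert tot_left_length >= L added in _forward_impl: the memory bank spans max_memory_size * chunk_length frames, and that span has to reach at least as far back as the left context. Illustrative numbers only:

chunk_length, left_context_length = 8, 32
assert 4 * chunk_length >= left_context_length        # max_memory_size = 4: accepted
assert not (3 * chunk_length >= left_context_length)  # max_memory_size = 3: rejected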

File 2 of 2: the tests for the Emformer model

@@ -22,7 +22,12 @@ def test_emformer_attention_forward():
     num_chunks = 3
     U = num_chunks * chunk_length
     R = num_chunks * right_context_length
-    attention = EmformerAttention(embed_dim=D, nhead=8)
+    attention = EmformerAttention(
+        embed_dim=D,
+        nhead=8,
+        chunk_length=chunk_length,
+        right_context_length=right_context_length,
+    )

     for use_memory in [True, False]:
         if use_memory:
@@ -39,10 +44,15 @@ def test_emformer_attention_forward():
         summary = torch.randn(S, B, D)
         memory = torch.randn(M, B, D)
         attention_mask = torch.rand(Q, KV) >= 0.5
-        PE = 2 * U - 1
+        PE = 2 * U + right_context_length - 1
         pos_emb = torch.randn(PE, D)

-        output_right_context_utterance, output_memory = attention(
+        (
+            output_right_context_utterance,
+            output_memory,
+            probs_memory,
+            probs_frames,
+        ) = attention(
             utterance,
             lengths,
             right_context,
@@ -53,16 +63,26 @@ def test_emformer_attention_forward():
         )
         assert output_right_context_utterance.shape == (R + U, B, D)
         assert output_memory.shape == (M, B, D)
+        assert probs_memory.shape == (B, U)
+        assert probs_frames.shape == (B, U)


 def test_emformer_attention_infer():
     from emformer import EmformerAttention

     B, D = 2, 256
-    U = 4
-    R = 2
+    chunk_length = 4
+    right_context_length = 2
+    num_chunks = 1
+    U = chunk_length * num_chunks
+    R = right_context_length * num_chunks
     L = 3
-    attention = EmformerAttention(embed_dim=D, nhead=8)
+    attention = EmformerAttention(
+        embed_dim=D,
+        nhead=8,
+        chunk_length=chunk_length,
+        right_context_length=right_context_length,
+    )

     for use_memory in [True, False]:
         if use_memory:
@@ -78,7 +98,12 @@ def test_emformer_attention_infer():
         memory = torch.randn(M, B, D)
         left_context_key = torch.randn(L, B, D)
         left_context_val = torch.randn(L, B, D)
-        PE = L + 2 * U - 1
+        PE = (
+            2 * U
+            + right_context_length
+            - 1
+            + (M * chunk_length if M > 0 else L)
+        )
         pos_emb = torch.randn(PE, D)

         (
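For reference, the PE expression in this test evaluates as follows with the constants above (chunk_length = 4, right_context_length = 2, U = 4, L = 3). The M = 3 case is only an assumed example: the value of M for use_memory=True is set earlier in the test and is not shown in this hunk.

for M in (3, 0):  # M = 3 is an assumption; M = 0 is the no-memory case
    PE = 2 * 4 + 2 - 1 + (M * 4 if M > 0 else 3)
    print(M, PE)  # -> 3 21, then 0 12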
@@ -197,7 +222,7 @@ def test_emformer_encoder_layer_forward():
         right_context = torch.randn(R, B, D)
         memory = torch.randn(M, B, D)
         attention_mask = torch.rand(Q, KV) >= 0.5
-        PE = 2 * U - 1
+        PE = 2 * U + right_context_length - 1
         pos_emb = torch.randn(PE, D)

         output_utterance, output_right_context, output_memory = layer(
@@ -227,8 +252,10 @@ def test_emformer_encoder_layer_infer():
     for use_memory in [True, False]:
         if use_memory:
-            M = 3
+            max_memory_size = 3
+            M = 1
         else:
+            max_memory_size = 0
             M = 0

         layer = EmformerEncoderLayer(
@@ -239,7 +266,7 @@ def test_emformer_encoder_layer_infer():
             cnn_module_kernel=kernel_size,
             left_context_length=left_context_length,
             right_context_length=right_context_length,
-            max_memory_size=M,
+            max_memory_size=max_memory_size,
         )

         utterance = torch.randn(U, B, D)
@@ -248,7 +275,16 @@ def test_emformer_encoder_layer_infer():
         right_context = torch.randn(R, B, D)
         memory = torch.randn(M, B, D)
         state = None
-        PE = left_context_length + 2 * U - 1
+        PE = (
+            2 * U
+            + right_context_length
+            - 1
+            + (
+                max_memory_size * chunk_length
+                if max_memory_size > 0
+                else left_context_length
+            )
+        )
         pos_emb = torch.randn(PE, D)
         conv_cache = None
         (
@@ -273,7 +309,7 @@ def test_emformer_encoder_layer_infer():
         else:
             assert output_memory.shape == (0, B, D)
         assert len(output_state) == 4
-        assert output_state[0].shape == (M, B, D)
+        assert output_state[0].shape == (max_memory_size, B, D)
         assert output_state[1].shape == (left_context_length, B, D)
         assert output_state[2].shape == (left_context_length, B, D)
         assert output_state[3].shape == (1, B)
@@ -334,9 +370,9 @@ def test_emformer_encoder_infer():
     for use_memory in [True, False]:
         if use_memory:
-            M = 3
+            max_memory_size = 3
         else:
-            M = 0
+            max_memory_size = 0

         encoder = EmformerEncoder(
             chunk_length=chunk_length,
@@ -346,7 +382,7 @@ def test_emformer_encoder_infer():
             cnn_module_kernel=kernel_size,
             left_context_length=left_context_length,
             right_context_length=right_context_length,
-            max_memory_size=M,
+            max_memory_size=max_memory_size,
         )

         states = None
@@ -368,7 +404,7 @@ def test_emformer_encoder_infer():
         assert len(states) == num_encoder_layers
         for state in states:
             assert len(state) == 4
-            assert state[0].shape == (M, B, D)
+            assert state[0].shape == (max_memory_size, B, D)
             assert state[1].shape == (left_context_length, B, D)
             assert state[2].shape == (left_context_length, B, D)
             assert torch.equal(
@@ -391,7 +427,7 @@ def test_emformer_encoder_forward_infer_consistency():
     kernel_size = 31
     memory_sizes = [0, 3]

-    for M in memory_sizes:
+    for max_memory_size in memory_sizes:
         encoder = EmformerEncoder(
             chunk_length=chunk_length,
             d_model=D,
@@ -400,7 +436,7 @@ def test_emformer_encoder_forward_infer_consistency():
             cnn_module_kernel=kernel_size,
             left_context_length=left_context_length,
             right_context_length=right_context_length,
-            max_memory_size=M,
+            max_memory_size=max_memory_size,
         )
         encoder.eval()
@@ -449,9 +485,9 @@ def test_emformer_forward():
     for use_memory in [True, False]:
         if use_memory:
-            M = 3
+            max_memory_size = 3
         else:
-            M = 0
+            max_memory_size = 0
         model = Emformer(
             num_features=num_features,
             chunk_length=chunk_length,
@@ -460,7 +496,7 @@ def test_emformer_forward():
             cnn_module_kernel=kernel_size,
             left_context_length=left_context_length,
             right_context_length=right_context_length,
-            max_memory_size=M,
+            max_memory_size=max_memory_size,
         )
         x = torch.randn(B, U + right_context_length + 3, num_features)
         x_lens = torch.randint(1, U + right_context_length + 3 + 1, (B,))
@@ -481,7 +517,7 @@ def test_emformer_infer():
     num_features = 80
     chunk_length = 8
     U = chunk_length
-    left_context_length, right_context_length = 128, 4
+    left_context_length, right_context_length = 32, 4
     B, D = 2, 256
     num_chunks = 3
     num_encoder_layers = 2
@@ -489,9 +525,9 @@ def test_emformer_infer():
     for use_memory in [True, False]:
         if use_memory:
-            M = 3
+            max_memory_size = 32
         else:
-            M = 0
+            max_memory_size = 0
         model = Emformer(
             num_features=num_features,
             chunk_length=chunk_length,
@@ -501,7 +537,7 @@ def test_emformer_infer():
             cnn_module_kernel=kernel_size,
             left_context_length=left_context_length,
             right_context_length=right_context_length,
-            max_memory_size=M,
+            max_memory_size=max_memory_size,
         )
         states = None
         conv_caches = None
@@ -523,7 +559,7 @@ def test_emformer_infer():
         assert len(states) == num_encoder_layers
         for state in states:
             assert len(state) == 4
-            assert state[0].shape == (M, B, D)
+            assert state[0].shape == (max_memory_size, B, D)
             assert state[1].shape == (left_context_length // 4, B, D)
             assert state[2].shape == (left_context_length // 4, B, D)
             assert torch.equal(