fix doc of stack and unstack, test case with batch_size=1

yaozengwei 2022-06-13 12:47:35 +08:00
parent adcbb4076d
commit a1cbe1fd9c
2 changed files with 49 additions and 37 deletions

@@ -43,15 +43,22 @@ def unstack_states(
    states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]]
) -> List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]]:
    """Unstack the emformer state corresponding to a batch of utterances
    into a list of states, where the i-th entry is the state from the i-th
    utterance in the batch.

    Args:
      states:
        A tuple of 2 elements.
        ``states[0]`` is the attention caches of a batch of utterances.
        ``states[1]`` is the convolution caches of a batch of utterances.
        ``len(states[0])`` and ``len(states[1])`` both equal the number of layers.  # noqa

    Returns:
      A list of states.
      ``states[i]`` is a tuple of 2 elements for the i-th utterance.
      ``states[i][0]`` is the attention caches of the i-th utterance.
      ``states[i][1]`` is the convolution caches of the i-th utterance.
      ``len(states[i][0])`` and ``len(states[i][1])`` both equal the number of layers.  # noqa
    """
    attn_caches, conv_caches = states
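
To make the cache layout concrete, here is a minimal sketch of the unstacking logic, assuming the shapes used in the test below: attention caches of shape (length, batch, dim) with the batch on dim 1, and convolution caches of shape (batch, dim, kernel_size - 1) with the batch on dim 0. It is an illustration under those assumptions, not the implementation in this commit.

    import torch
    from typing import List, Tuple


    def unstack_states_sketch(
        states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]]
    ) -> List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]]:
        # Attention caches index the batch on dim 1; convolution caches on
        # dim 0 (shapes assumed from the test in this commit).
        attn_caches, conv_caches = states
        batch_size = conv_caches[0].size(0)
        state_list = []
        for i in range(batch_size):
            # Keep a batch dimension of size 1 so each per-utterance state
            # can still be passed to model.infer on its own; the repository's
            # implementation may use a different convention.
            attn_i = [[t[:, i : i + 1] for t in layer] for layer in attn_caches]
            conv_i = [t[i : i + 1] for t in conv_caches]
            state_list.append((attn_i, conv_i))
        return state_list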
@@ -85,7 +92,6 @@ def unstack_states(

def stack_states(
    state_list: List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]]
) -> Tuple[List[List[torch.Tensor]], List[torch.Tensor]]:
    """Stack a list of emformer states that correspond to separate utterances
    into a single emformer state, so that it can be used as an input for
    emformer when those utterances are formed into a batch.
@@ -97,8 +103,13 @@ def stack_states(
      state_list:
        Each element in state_list corresponds to the internal state
        of the emformer model for a single utterance.
        ``states[i]`` is a tuple of 2 elements for the i-th utterance.
        ``states[i][0]`` is the attention caches of the i-th utterance.
        ``states[i][1]`` is the convolution caches of the i-th utterance.
        ``len(states[i][0])`` and ``len(states[i][1])`` both equal the number of layers.  # noqa

    Returns:
      A new state corresponding to a batch of utterances.
      See the input argument of :func:`unstack_states` for the meaning
      of the returned tensor.
    """

@@ -85,7 +85,6 @@ def test_state_stack_unstack():
    left_context_length = 32
    right_context_length = 8
    memory_size = 32
    model = Emformer(
        num_features=num_features,
@@ -98,40 +97,42 @@ def test_state_stack_unstack():
        right_context_length=right_context_length,
        memory_size=memory_size,
    )

    for batch_size in [1, 2]:
        attn_caches = [
            [
                torch.zeros(memory_size, batch_size, encoder_dim),
                torch.zeros(left_context_length // 4, batch_size, encoder_dim),
                torch.zeros(
                    left_context_length // 4,
                    batch_size,
                    encoder_dim,
                ),
            ]
            for _ in range(num_encoder_layers)
        ]
        conv_caches = [
            torch.zeros(batch_size, encoder_dim, kernel_size - 1)
            for _ in range(num_encoder_layers)
        ]
        states = [attn_caches, conv_caches]
        x = torch.randn(batch_size, 23, num_features)
        x_lens = torch.full((batch_size,), 23)
        num_processed_frames = torch.full((batch_size,), 0)
        y, y_lens, states = model.infer(
            x, x_lens, num_processed_frames=num_processed_frames, states=states
        )

        state_list = unstack_states(states)
        states2 = stack_states(state_list)

        for ss, ss2 in zip(states[0], states2[0]):
            for s, s2 in zip(ss, ss2):
                assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}"
        for s, s2 in zip(states[1], states2[1]):
            assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}"
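
For context, a hypothetical sketch of how these helpers fit into streaming decoding: per-stream states are stacked into one batched state for a single infer() call, then unstacked so streams can be added or dropped independently. The batch_size=1 case exercised above is what happens when only one stream is active. Everything here except stack_states, unstack_states, and the model.infer signature (taken from the test) is an illustrative assumption.

    def infer_one_chunk(model, active_streams, x, x_lens, num_processed_frames):
        # `active_streams` is a hypothetical list of objects, each holding
        # its per-utterance state in a `states` attribute; stack_states and
        # unstack_states are assumed imported from the module changed above.
        states = stack_states([s.states for s in active_streams])
        y, y_lens, states = model.infer(
            x, x_lens, num_processed_frames=num_processed_frames, states=states
        )
        # Give each stream back its own caches; finished streams can now be
        # removed without disturbing the others.
        for stream, state in zip(active_streams, unstack_states(states)):
            stream.states = state
        return y, y_lens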
def test_torchscript_consistency_infer():
    r"""Verify that scripting Emformer does not change the behavior of method `infer`."""  # noqa