fix doc of stack and unstack, test case with batch_size=1

yaozengwei 2022-06-13 12:47:35 +08:00
parent adcbb4076d
commit a1cbe1fd9c
2 changed files with 49 additions and 37 deletions

View File

@@ -43,12 +43,19 @@ def unstack_states(
states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]]
) -> List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]]:
"""Unstack the emformer state corresponding to a batch of utterances
into a list of states, were the i-th entry is the state from the i-th
into a list of states, where the i-th entry is the state from the i-th
utterance in the batch.
Args:
states:
A list of tuples.
A tuple of 2 elements.
``states[0]`` is the attention caches of a batch of utterances.
``states[1]`` is the convolution caches of a batch of utterances.
``len(states[0])`` and ``len(states[1])`` are both equal to the number of layers. # noqa
Returns:
A list of states.
``states[i]`` is a tuple of 2 elements for the i-th utterance.
``states[i][0]`` is the attention caches of the i-th utterance.
``states[i][1]`` is the convolution caches of the i-th utterance.
``len(states[i][0])`` and ``len(states[i][1])`` are both equal to the number of layers. # noqa
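
The hunk above only shows the docstring, so here is a small, hedged sketch of the batched layout it describes. The layer count and tensor shapes are arbitrary placeholders (the attention-cache shape mirrors the test file further down, and the convolution-cache shape is purely illustrative); only the nesting follows the documented structure.

import torch

num_layers = 12                      # arbitrary illustration values
batch_size = 2
memory_size, encoder_dim = 32, 512   # placeholder cache dimensions

# states[0]: attention caches, one list of tensors per layer.
attn_caches = [
    [torch.zeros(memory_size, batch_size, encoder_dim)]  # real layers hold several cache tensors
    for _ in range(num_layers)
]
# states[1]: convolution caches, one tensor per layer.
# The shape below is a placeholder; only the batch-on-dim-1 convention is assumed here.
conv_caches = [
    torch.zeros(8, batch_size, encoder_dim) for _ in range(num_layers)
]
states = (attn_caches, conv_caches)

assert len(states[0]) == num_layers and len(states[1]) == num_layers
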
@@ -85,7 +92,6 @@ def unstack_states(
def stack_states(
state_list: List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]]
) -> Tuple[List[List[torch.Tensor]], List[torch.Tensor]]:
# TODO: modify doc
"""Stack list of emformer states that correspond to separate utterances
into a single emformer state so that it can be used as an input for
emformer when those utterances are formed into a batch.
@@ -97,8 +103,13 @@ def stack_states(
state_list:
Each element in state_list corresponds to the internal state
of the emformer model for a single utterance.
``state_list[i]`` is a tuple of 2 elements for the i-th utterance.
``state_list[i][0]`` is the attention caches of the i-th utterance.
``state_list[i][1]`` is the convolution caches of the i-th utterance.
``len(state_list[i][0])`` and ``len(state_list[i][1])`` are both equal to the number of layers. # noqa
Returns:
Return a new state corresponding to a batch of utterances.
A new state corresponding to a batch of utterances.
See the input argument of :func:`unstack_states` for the meaning
of the returned tensor.
"""

View File

@@ -85,7 +85,6 @@ def test_state_stack_unstack():
left_context_length = 32
right_context_length = 8
memory_size = 32
batch_size = 2
model = Emformer(
num_features=num_features,
@@ -98,6 +97,8 @@ def test_state_stack_unstack():
right_context_length=right_context_length,
memory_size=memory_size,
)
for batch_size in [1, 2]:
attn_caches = [
[
torch.zeros(memory_size, batch_size, encoder_dim),