mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
fix doc of stack and unstack, test case with batch_size=1
This commit is contained in:
parent
adcbb4076d
commit
a1cbe1fd9c
@ -43,12 +43,19 @@ def unstack_states(
|
||||
states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]]
|
||||
) -> List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]]:
|
||||
"""Unstack the emformer state corresponding to a batch of utterances
|
||||
into a list of states, were the i-th entry is the state from the i-th
|
||||
into a list of states, where the i-th entry is the state from the i-th
|
||||
utterance in the batch.
|
||||
|
||||
Args:
|
||||
states:
|
||||
A list of tuples.
|
||||
A tuple of 2 elements.
|
||||
``states[0]`` is the attention caches of a batch of utterances.
|
||||
``states[1]`` is the convolution caches of a batch of utterances.
|
||||
``len(states[0])`` and ``len(states[1])`` are both equal to the number of layers. # noqa
|
||||
|
||||
Returns:
|
||||
A list of states.
|
||||
``states[i]`` is a tuple of 2 elements of i-th utterance.
|
||||
``states[i][0]`` is the attention caches of i-th utterance.
|
||||
``states[i][1]`` is the convolution caches of i-th utterance.
|
||||
``len(states[i][0])`` and ``len(states[i][1])`` are both equal to the number of layers. # noqa
|
||||
@ -85,7 +92,6 @@ def unstack_states(
|
||||
def stack_states(
|
||||
state_list: List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]]
|
||||
) -> Tuple[List[List[torch.Tensor]], List[torch.Tensor]]:
|
||||
# TODO: modify doc
|
||||
"""Stack list of emformer states that correspond to separate utterances
|
||||
into a single emformer state so that it can be used as an input for
|
||||
emformer when those utterances are formed into a batch.
|
||||
@ -97,8 +103,13 @@ def stack_states(
|
||||
state_list:
|
||||
Each element in state_list corresponds to the internal state
|
||||
of the emformer model for a single utterance.
|
||||
``state_list[i]`` is a tuple of 2 elements of i-th utterance.
|
||||
``state_list[i][0]`` is the attention caches of i-th utterance.
|
||||
``state_list[i][1]`` is the convolution caches of i-th utterance.
|
||||
``len(state_list[i][0])`` and ``len(state_list[i][1])`` are both equal to the number of layers. # noqa
|
||||
|
||||
Returns:
|
||||
Return a new state corresponding to a batch of utterances.
|
||||
A new state corresponding to a batch of utterances.
|
||||
See the input argument of :func:`unstack_states` for the meaning
|
||||
of the returned state.
|
||||
"""
|
||||
|
@ -85,7 +85,6 @@ def test_state_stack_unstack():
|
||||
left_context_length = 32
|
||||
right_context_length = 8
|
||||
memory_size = 32
|
||||
batch_size = 2
|
||||
|
||||
model = Emformer(
|
||||
num_features=num_features,
|
||||
@ -98,6 +97,8 @@ def test_state_stack_unstack():
|
||||
right_context_length=right_context_length,
|
||||
memory_size=memory_size,
|
||||
)
|
||||
|
||||
for batch_size in [1, 2]:
|
||||
attn_caches = [
|
||||
[
|
||||
torch.zeros(memory_size, batch_size, encoder_dim),
|
||||
|
Loading…
x
Reference in New Issue
Block a user