mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
fix doc of stack and unstack, test case with batch_size=1
This commit is contained in:
parent
adcbb4076d
commit
a1cbe1fd9c
@ -43,15 +43,22 @@ def unstack_states(
|
||||
states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]]
|
||||
) -> List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]]:
|
||||
"""Unstack the emformer state corresponding to a batch of utterances
|
||||
into a list of states, were the i-th entry is the state from the i-th
|
||||
into a list of states, where the i-th entry is the state from the i-th
|
||||
utterance in the batch.
|
||||
|
||||
Args:
|
||||
states:
|
||||
A list of tuples.
|
||||
``states[i][0]`` is the attention caches of i-th utterance.
|
||||
``states[i][1]`` is the convolution caches of i-th utterance.
|
||||
``len(states[i][0])`` and ``len(states[i][1])`` both eqaul to number of layers. # noqa
|
||||
A tuple of 2 elements.
|
||||
``states[0]`` is the attention caches of a batch of utterance.
|
||||
``states[1]`` is the convolution caches of a batch of utterance.
|
||||
``len(states[0])`` and ``len(states[1])`` both eqaul to number of layers. # noqa
|
||||
|
||||
Returns:
|
||||
A list of states.
|
||||
``states[i]`` is a tuple of 2 elements of i-th utterance.
|
||||
``states[i][0]`` is the attention caches of i-th utterance.
|
||||
``states[i][1]`` is the convolution caches of i-th utterance.
|
||||
``len(states[i][0])`` and ``len(states[i][1])`` both eqaul to number of layers. # noqa
|
||||
"""
|
||||
|
||||
attn_caches, conv_caches = states
|
||||
@ -85,7 +92,6 @@ def unstack_states(
|
||||
def stack_states(
|
||||
state_list: List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]]
|
||||
) -> Tuple[List[List[torch.Tensor]], List[torch.Tensor]]:
|
||||
# TODO: modify doc
|
||||
"""Stack list of emformer states that correspond to separate utterances
|
||||
into a single emformer state so that it can be used as an input for
|
||||
emformer when those utterances are formed into a batch.
|
||||
@ -97,8 +103,13 @@ def stack_states(
|
||||
state_list:
|
||||
Each element in state_list corresponding to the internal state
|
||||
of the emformer model for a single utterance.
|
||||
``states[i]`` is a tuple of 2 elements of i-th utterance.
|
||||
``states[i][0]`` is the attention caches of i-th utterance.
|
||||
``states[i][1]`` is the convolution caches of i-th utterance.
|
||||
``len(states[i][0])`` and ``len(states[i][1])`` both eqaul to number of layers. # noqa
|
||||
|
||||
Returns:
|
||||
Return a new state corresponding to a batch of utterances.
|
||||
A new state corresponding to a batch of utterances.
|
||||
See the input argument of :func:`unstack_states` for the meaning
|
||||
of the returned tensor.
|
||||
"""
|
||||
|
@ -85,7 +85,6 @@ def test_state_stack_unstack():
|
||||
left_context_length = 32
|
||||
right_context_length = 8
|
||||
memory_size = 32
|
||||
batch_size = 2
|
||||
|
||||
model = Emformer(
|
||||
num_features=num_features,
|
||||
@ -98,40 +97,42 @@ def test_state_stack_unstack():
|
||||
right_context_length=right_context_length,
|
||||
memory_size=memory_size,
|
||||
)
|
||||
attn_caches = [
|
||||
[
|
||||
torch.zeros(memory_size, batch_size, encoder_dim),
|
||||
torch.zeros(left_context_length // 4, batch_size, encoder_dim),
|
||||
torch.zeros(
|
||||
left_context_length // 4,
|
||||
batch_size,
|
||||
encoder_dim,
|
||||
),
|
||||
|
||||
for batch_size in [1, 2]:
|
||||
attn_caches = [
|
||||
[
|
||||
torch.zeros(memory_size, batch_size, encoder_dim),
|
||||
torch.zeros(left_context_length // 4, batch_size, encoder_dim),
|
||||
torch.zeros(
|
||||
left_context_length // 4,
|
||||
batch_size,
|
||||
encoder_dim,
|
||||
),
|
||||
]
|
||||
for _ in range(num_encoder_layers)
|
||||
]
|
||||
for _ in range(num_encoder_layers)
|
||||
]
|
||||
conv_caches = [
|
||||
torch.zeros(batch_size, encoder_dim, kernel_size - 1)
|
||||
for _ in range(num_encoder_layers)
|
||||
]
|
||||
states = [attn_caches, conv_caches]
|
||||
x = torch.randn(batch_size, 23, num_features)
|
||||
x_lens = torch.full((batch_size,), 23)
|
||||
num_processed_frames = torch.full((batch_size,), 0)
|
||||
y, y_lens, states = model.infer(
|
||||
x, x_lens, num_processed_frames=num_processed_frames, states=states
|
||||
)
|
||||
conv_caches = [
|
||||
torch.zeros(batch_size, encoder_dim, kernel_size - 1)
|
||||
for _ in range(num_encoder_layers)
|
||||
]
|
||||
states = [attn_caches, conv_caches]
|
||||
x = torch.randn(batch_size, 23, num_features)
|
||||
x_lens = torch.full((batch_size,), 23)
|
||||
num_processed_frames = torch.full((batch_size,), 0)
|
||||
y, y_lens, states = model.infer(
|
||||
x, x_lens, num_processed_frames=num_processed_frames, states=states
|
||||
)
|
||||
|
||||
state_list = unstack_states(states)
|
||||
states2 = stack_states(state_list)
|
||||
state_list = unstack_states(states)
|
||||
states2 = stack_states(state_list)
|
||||
|
||||
for ss, ss2 in zip(states[0], states2[0]):
|
||||
for s, s2 in zip(ss, ss2):
|
||||
for ss, ss2 in zip(states[0], states2[0]):
|
||||
for s, s2 in zip(ss, ss2):
|
||||
assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}"
|
||||
|
||||
for s, s2 in zip(states[1], states2[1]):
|
||||
assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}"
|
||||
|
||||
for s, s2 in zip(states[1], states2[1]):
|
||||
assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}"
|
||||
|
||||
|
||||
def test_torchscript_consistency_infer():
|
||||
r"""Verify that scripting Emformer does not change the behavior of method `infer`.""" # noqa
|
||||
|
Loading…
x
Reference in New Issue
Block a user