mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-04 22:54:18 +00:00
Support state stacking and unstacking operations for emformer_pruned_transducer_stateless/emformer.py
This commit is contained in:
parent 39c6c1be87
commit 328ad280a4
@@ -77,6 +77,75 @@ def _gen_attention_mask_block(
    return torch.cat(mask_block, dim=1)


def unstack_states(
    states: List[List[torch.Tensor]],
) -> List[List[List[torch.Tensor]]]:
    """Unstack the emformer state corresponding to a batch of utterances
    into a list of states, where the i-th entry is the state from the i-th
    utterance in the batch.

    Args:
      states:
        A list-of-list of tensors. ``len(states)`` equals the number of
        layers in the emformer. ``states[i]`` contains the states for
        the i-th layer. ``states[i][k]`` is either a 3-D tensor of shape
        ``(T, N, C)`` or a 2-D tensor of shape ``(C, N)``.
    """
    batch_size = states[0][0].size(1)
    num_layers = len(states)

    ans = [None] * batch_size
    for i in range(batch_size):
        ans[i] = [[] for _ in range(num_layers)]

    for li, layer in enumerate(states):
        for s in layer:
            s_list = s.unbind(dim=1)
            # We will use stack(dim=1) later in stack_states()
            for bi, b in enumerate(ans):
                b[li].append(s_list[bi])
    return ans
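A quick orientation for the data layout, written as a minimal sketch rather than code from this commit: the toy layer count and tensor shapes below are assumptions, and the import path mirrors the one used by the new test further down. `unstack_states` splits each batched tensor along dim 1 while preserving the per-layer list structure:

import torch

from emformer import unstack_states

# Toy batched state: 2 layers, batch size 3; each layer holds one
# (T, N, C) tensor and one (C, N) tensor, matching the docstring above.
num_layers, N = 2, 3
toy_states = [
    [torch.randn(4, N, 8), torch.randn(8, N)] for _ in range(num_layers)
]

per_utt = unstack_states(toy_states)
assert len(per_utt) == N                 # one entry per utterance
assert len(per_utt[0]) == num_layers     # per-layer structure is preserved
assert per_utt[0][0][0].shape == (4, 8)  # batch dim removed: (T, C)
assert per_utt[0][0][1].shape == (8,)    # batch dim removed: (C,)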


def stack_states(
    state_list: List[List[List[torch.Tensor]]],
) -> List[List[torch.Tensor]]:
    """Stack a list of emformer states, each corresponding to a separate
    utterance, into a single emformer state so that it can be used as an
    input to the emformer when those utterances are formed into a batch.

    Note:
      It is the inverse of :func:`unstack_states`.

    Args:
      state_list:
        Each element in state_list corresponds to the internal state
        of the emformer model for a single utterance.

    Returns:
      A new state corresponding to a batch of utterances.
      See the ``states`` argument of :func:`unstack_states` for the
      meaning of the returned tensors.
    """
    batch_size = len(state_list)
    ans = []
    for layer in state_list[0]:
        # layer is a list of tensors
        if batch_size > 1:
            ans.append([[s] for s in layer])
            # Note: ans[li][si] is a list for now; it is stacked over the
            # batch dim in the loop below.
        else:
            ans.append([s.unsqueeze(1) for s in layer])

    for b, states in enumerate(state_list[1:], 1):
        for li, layer in enumerate(states):
            for si, s in enumerate(layer):
                ans[li][si].append(s)
                if b == batch_size - 1:
                    ans[li][si] = torch.stack(ans[li][si], dim=1)
                    # We will use unbind(dim=1) later in unstack_states()
    return ans
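The inverse direction, again as a hedged sketch with assumed toy shapes: `stack_states` stacks the per-utterance tensors along a new dim 1, so a round trip through both helpers reproduces the original tensors, which is what the new test below checks against a real model:

import torch

from emformer import stack_states, unstack_states

# Two independent per-utterance states (no batch dimension), 2 layers each.
num_layers = 2
per_utt = [
    [[torch.randn(4, 8), torch.randn(8)] for _ in range(num_layers)]
    for _ in range(2)
]

batched = stack_states(per_utt)
assert batched[0][0].shape == (4, 2, 8)  # stacked into (T, N, C), N == 2
assert batched[0][1].shape == (8, 2)     # stacked into (C, N), N == 2

# Round trip returns the original tensors for every utterance and layer.
for utt, utt2 in zip(per_utt, unstack_states(batched)):
    for layer, layer2 in zip(utt, utt2):
        for s, s2 in zip(layer, layer2):
            assert torch.equal(s, s2)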


class EmformerAttention(nn.Module):
    r"""Emformer layer attention module.

@@ -424,9 +493,9 @@ class EmformerAttention(nn.Module):
         # key, value: [memory, right context, left context, utterance]
         KV = (
             memory.size(0)
-            + right_context.size(0)
-            + left_context_key.size(0)
-            + utterance.size(0)
+            + right_context.size(0)  # noqa
+            + left_context_key.size(0)  # noqa
+            + utterance.size(0)  # noqa
         )
         attention_mask = torch.zeros(Q, KV).to(
             dtype=torch.bool, device=utterance.device
@@ -682,6 +682,44 @@ def test_emformer_infer_batch_single_consistency():
    assert torch.allclose(batch_logits, single_logits, atol=1e-5, rtol=0.0)


def test_emformer_infer_states_stack():
    from emformer import Emformer, unstack_states, stack_states

    num_features = 80
    output_dim = 1000
    chunk_length = 8
    U = chunk_length
    L, R = 128, 4
    B, D = 2, 256
    num_encoder_layers = 2
    for use_memory in [True, False]:
        if use_memory:
            M = 3
        else:
            M = 0
        model = Emformer(
            num_features=num_features,
            output_dim=output_dim,
            chunk_length=chunk_length,
            subsampling_factor=4,
            d_model=D,
            num_encoder_layers=num_encoder_layers,
            left_context_length=L,
            right_context_length=R,
            max_memory_size=M,
            vgg_frontend=False,
        )

        x = torch.randn(B, U + R + 3, num_features)
        x_lens = torch.full((B,), U + R + 3)
        logits, output_lengths, states = model.infer(x, x_lens)
        states2 = stack_states(unstack_states(states))

        for ss, ss2 in zip(states, states2):
            for s, s2 in zip(ss, ss2):
                assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}"


if __name__ == "__main__":
    test_emformer_attention_forward()
    test_emformer_attention_infer()
@@ -695,3 +733,4 @@ if __name__ == "__main__":
    test_emformer_layer_forward_infer_consistency()
    test_emformer_encoder_forward_infer_consistency()
    test_emformer_infer_batch_single_consistency()
    test_emformer_infer_states_stack()