Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-08 16:44:20 +00:00
Support state stacking and unstacking operations for emformer_pruned_transducer_stateless/emformer.py
This commit is contained in: parent 39c6c1be87, commit 328ad280a4
@@ -77,6 +77,75 @@ def _gen_attention_mask_block(
     return torch.cat(mask_block, dim=1)
 
 
+def unstack_states(
+    states: List[List[torch.Tensor]],
+) -> List[List[List[torch.Tensor]]]:
+    """Unstack the emformer state corresponding to a batch of utterances
+    into a list of states, where the i-th entry is the state from the i-th
+    utterance in the batch.
+
+    Args:
+      states:
+        A list of lists of tensors. ``len(states)`` equals the number of
+        layers in the emformer. ``states[i]`` contains the states for
+        the i-th layer. ``states[i][k]`` is either a 3-D tensor of shape
+        ``(T, N, C)`` or a 2-D tensor of shape ``(C, N)``.
+    """
+    batch_size = states[0][0].size(1)
+    num_layers = len(states)
+
+    ans = [None] * batch_size
+    for i in range(batch_size):
+        ans[i] = [[] for _ in range(num_layers)]
+
+    for li, layer in enumerate(states):
+        for s in layer:
+            s_list = s.unbind(dim=1)
+            # We will use stack(dim=1) later in stack_states()
+            for bi, b in enumerate(ans):
+                b[li].append(s_list[bi])
+    return ans
+
+
+def stack_states(
+    state_list: List[List[List[torch.Tensor]]],
+) -> List[List[torch.Tensor]]:
+    """Stack a list of emformer states that correspond to separate utterances
+    into a single emformer state, so that it can be used as an input for
+    the emformer when those utterances are formed into a batch.
+
+    Note:
+      It is the inverse of :func:`unstack_states`.
+
+    Args:
+      state_list:
+        Each element in state_list corresponds to the internal state
+        of the emformer model for a single utterance.
+
+    Returns:
+      Return a new state corresponding to a batch of utterances.
+      See the input argument of :func:`unstack_states` for the meaning
+      of the returned tensors.
+    """
+    batch_size = len(state_list)
+    ans = []
+    for layer in state_list[0]:
+        # layer is a list of tensors
+        if batch_size > 1:
+            ans.append([[s] for s in layer])
+            # Note: We will stack ans[layer][s][] later to get ans[layer][s]
+        else:
+            ans.append([s.unsqueeze(1) for s in layer])
+
+    for b, states in enumerate(state_list[1:], 1):
+        for li, layer in enumerate(states):
+            for si, s in enumerate(layer):
+                ans[li][si].append(s)
+                if b == batch_size - 1:
+                    ans[li][si] = torch.stack(ans[li][si], dim=1)
+                    # We will use unbind(dim=1) later in unstack_states()
+    return ans
+
+
 class EmformerAttention(nn.Module):
     r"""Emformer layer attention module.
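
For concreteness, here is a minimal round-trip sketch (not part of the commit) using synthetic tensors in place of real emformer states, assuming the two helpers are importable as in the test added below (``from emformer import stack_states, unstack_states``); the sizes are illustrative only:

import torch

from emformer import stack_states, unstack_states

# Fake states: every layer holds one 3-D tensor (T, N, C) and one 2-D
# tensor (C, N), where N is the batch size.
num_layers, T, N, C = 2, 4, 3, 8
states = [
    [torch.randn(T, N, C), torch.randn(C, N)]
    for _ in range(num_layers)
]

per_utt = unstack_states(states)   # one entry per utterance in the batch
assert len(per_utt) == N and len(per_utt[0]) == num_layers

rebatched = stack_states(per_utt)  # inverse of unstack_states
for layer, layer2 in zip(states, rebatched):
    for s, s2 in zip(layer, layer2):
        assert torch.equal(s, s2)
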
@@ -424,9 +493,9 @@ class EmformerAttention(nn.Module):
         # key, value: [memory, right context, left context, utterance]
         KV = (
             memory.size(0)
-            + right_context.size(0)
-            + left_context_key.size(0)
-            + utterance.size(0)
+            + right_context.size(0)  # noqa
+            + left_context_key.size(0)  # noqa
+            + utterance.size(0)  # noqa
         )
         attention_mask = torch.zeros(Q, KV).to(
             dtype=torch.bool, device=utterance.device
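
As a quick orientation (not part of the commit), ``KV`` is simply the total number of key/value positions contributed by the four segments listed in the comment, and the ``attention_mask`` allocated right after it has one boolean row per query position; the numbers below are invented:

# Illustrative segment lengths: memory, right context, left context, utterance.
M, R, L, U = 3, 4, 128, 8
KV = M + R + L + U  # 143 key/value positions in total
# attention_mask then has shape (Q, KV): one row per query, one column per key/value.
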
@@ -682,6 +682,44 @@ def test_emformer_infer_batch_single_consistency():
         assert torch.allclose(batch_logits, single_logits, atol=1e-5, rtol=0.0)
 
 
+def test_emformer_infer_states_stack():
+    from emformer import Emformer, unstack_states, stack_states
+
+    num_features = 80
+    output_dim = 1000
+    chunk_length = 8
+    U = chunk_length
+    L, R = 128, 4
+    B, D = 2, 256
+    num_encoder_layers = 2
+    for use_memory in [True, False]:
+        if use_memory:
+            M = 3
+        else:
+            M = 0
+        model = Emformer(
+            num_features=num_features,
+            output_dim=output_dim,
+            chunk_length=chunk_length,
+            subsampling_factor=4,
+            d_model=D,
+            num_encoder_layers=num_encoder_layers,
+            left_context_length=L,
+            right_context_length=R,
+            max_memory_size=M,
+            vgg_frontend=False,
+        )
+
+        x = torch.randn(B, U + R + 3, num_features)
+        x_lens = torch.full((B,), U + R + 3)
+        logits, output_lengths, states = model.infer(x, x_lens)
+        states2 = stack_states(unstack_states(states))
+
+        for ss, ss2 in zip(states, states2):
+            for s, s2 in zip(ss, ss2):
+                assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}"
+
+
 if __name__ == "__main__":
     test_emformer_attention_forward()
     test_emformer_attention_infer()
@@ -695,3 +733,4 @@ if __name__ == "__main__":
     test_emformer_layer_forward_infer_consistency()
     test_emformer_encoder_forward_infer_consistency()
     test_emformer_infer_batch_single_consistency()
+    test_emformer_infer_states_stack()
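
Beyond the round-trip test, one plausible way to pair the two helpers is in a streaming decode loop: keep one state per audio stream, stack the states of the currently active streams before each chunk, and unstack the updated batched state afterwards. The sketch below is not code from this commit; ``streams``, ``stream.state``, and ``get_next_chunk`` are hypothetical bookkeeping names, and it assumes ``model.infer`` also accepts the previous batched state as an argument:

# Sketch only: per-stream state management around a batched infer() call.
active = [s for s in streams if not s.done]
chunk = torch.stack([get_next_chunk(s) for s in active])  # (B, chunk_frames, num_features)
chunk_lens = torch.full((len(active),), chunk.size(1))
batched_state = stack_states([s.state for s in active])

logits, out_lens, batched_state = model.infer(chunk, chunk_lens, batched_state)

# Hand each stream back its own slice of the updated state.
for s, state in zip(active, unstack_states(batched_state)):
    s.state = state
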