Support state stacking and unstacking operations for emformer_pruned_transducer_stateless/emformer.py

This commit is contained in:
yaozengwei 2022-04-19 17:58:51 +08:00
parent 39c6c1be87
commit 328ad280a4
2 changed files with 111 additions and 3 deletions

View File

@ -77,6 +77,75 @@ def _gen_attention_mask_block(
return torch.cat(mask_block, dim=1)
def unstack_states(
states: List[List[torch.Tensor]],
) -> List[List[List[torch.Tensor]]]:
"""Unstack the emformer state corresponding to a batch of utterances
into a list of states, were the i-th entry is the state from the i-th
utterance in the batch.
Args:
states:
A list-of-list of tensors. ``len(states)`` equals to number of
layers in the emformer. ``states[i]]`` contains the states for
the i-th layer. ``states[i][k]`` is either a 3-D tensor of shape
``(T, N, C)`` or a 2-D tensor of shape ``(C, N)``
"""
batch_size = states[0][0].size(1)
num_layers = len(states)
ans = [None] * batch_size
for i in range(batch_size):
ans[i] = [[] for _ in range(num_layers)]
for li, layer in enumerate(states):
for s in layer:
s_list = s.unbind(dim=1)
# We will use stack(dim=1) later in stack_states()
for bi, b in enumerate(ans):
b[li].append(s_list[bi])
return ans
def stack_states(
state_list: List[List[List[torch.Tensor]]],
) -> List[List[torch.Tensor]]:
"""Stack list of emformer states that correspond to separate utterances
into a single emformer state so that it can be used as an input for
emformer when those utterances are formed into a batch.
Note:
It is the inverse of :func:`unstack_states`.
Args:
state_list:
Each element in state_list corresponding to the internal state
of the emformer model for a single utterance.
Returns:
Return a new state corresponding to a batch of utterances.
See the input argument of :func:`unstack_states` for the meaning
of the returned tensor.
"""
batch_size = len(state_list)
ans = []
for layer in state_list[0]:
# layer is a list of tensors
if batch_size > 1:
ans.append([[s] for s in layer])
# Note: We will stack ans[layer][s][] later to get ans[layer][s]
else:
ans.append([s.unsqueeze(1) for s in layer])
for b, states in enumerate(state_list[1:], 1):
for li, layer in enumerate(states):
for si, s in enumerate(layer):
ans[li][si].append(s)
if b == batch_size - 1:
ans[li][si] = torch.stack(ans[li][si], dim=1)
# We will use unbind(dim=1) later in unstack_states()
return ans
class EmformerAttention(nn.Module):
r"""Emformer layer attention module.
@ -424,9 +493,9 @@ class EmformerAttention(nn.Module):
# key, value: [memory, right context, left context, uttrance]
KV = (
memory.size(0)
+ right_context.size(0)
+ left_context_key.size(0)
+ utterance.size(0)
+ right_context.size(0) # noqa
+ left_context_key.size(0) # noqa
+ utterance.size(0) # noqa
)
attention_mask = torch.zeros(Q, KV).to(
dtype=torch.bool, device=utterance.device

View File

@ -682,6 +682,44 @@ def test_emformer_infer_batch_single_consistency():
assert torch.allclose(batch_logits, single_logits, atol=1e-5, rtol=0.0)
def test_emformer_infer_states_stack():
from emformer import Emformer, unstack_states, stack_states
num_features = 80
output_dim = 1000
chunk_length = 8
U = chunk_length
L, R = 128, 4
B, D = 2, 256
num_encoder_layers = 2
for use_memory in [True, False]:
if use_memory:
M = 3
else:
M = 0
model = Emformer(
num_features=num_features,
output_dim=output_dim,
chunk_length=chunk_length,
subsampling_factor=4,
d_model=D,
num_encoder_layers=num_encoder_layers,
left_context_length=L,
right_context_length=R,
max_memory_size=M,
vgg_frontend=False,
)
x = torch.randn(B, U + R + 3, num_features)
x_lens = torch.full((B, ), U + R + 3)
logits, output_lengths, states = model.infer(x, x_lens,)
states2 = stack_states(unstack_states(states))
for ss, ss2 in zip(states, states2):
for s, s2 in zip(ss, ss2):
assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}"
if __name__ == "__main__":
test_emformer_attention_forward()
test_emformer_attention_infer()
@ -695,3 +733,4 @@ if __name__ == "__main__":
test_emformer_layer_forward_infer_consistency()
test_emformer_encoder_forward_infer_consistency()
test_emformer_infer_batch_single_consistency()
test_emformer_infer_states_stack()