test states stack and unstack

This commit is contained in:
yaozengwei 2022-06-09 22:17:03 +08:00
parent 7f09720403
commit 507d7c13f4
3 changed files with 76 additions and 21 deletions

View File

@@ -557,7 +557,7 @@ class HypothesisList(object):
        return ", ".join(s)

-def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
+def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
    """Return a ragged shape with axes [utt][num_hyps].

    Args:
@@ -648,7 +648,7 @@ def modified_beam_search(
            finalized_B = B[batch_size:] + finalized_B
            B = B[:batch_size]

-        hyps_shape = _get_hyps_shape(B).to(device)
+        hyps_shape = get_hyps_shape(B).to(device)

        A = [list(b) for b in B]
        B = [HypothesisList() for _ in range(batch_size)]
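Both hunks above are one rename: _get_hyps_shape becomes get_hyps_shape (and its call site in modified_beam_search follows), presumably so the helper can be imported from code outside this file. For reference, a minimal sketch of how such a [utt][num_hyps] shape is typically built with k2; the _sketch suffix marks the body as illustrative, not necessarily the function's exact code:

import torch
import k2

def get_hyps_shape_sketch(hyps) -> k2.RaggedShape:
    # Row i of the ragged shape has one entry per hypothesis of utterance i.
    num_hyps = [len(h) for h in hyps]
    # Prepend 0 so the inclusive cumsum yields exclusive row_splits,
    # e.g. num_hyps [2, 3] -> row_splits [0, 2, 5].
    row_splits = torch.cumsum(
        torch.tensor([0] + num_hyps), dim=0, dtype=torch.int32
    )
    return k2.ragged.create_ragged_shape2(
        row_splits=row_splits, cached_tot_size=row_splits[-1].item()
    )

modified_beam_search then moves this shape to the device and uses it to regroup flattened per-hypothesis tensors by utterance.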

View File

@@ -302,6 +302,7 @@ def decode_one_chunk(
    # update cached states of each stream
    state_list = unstack_states(states)

+    assert len(streams) == len(state_list)

    for i, s in enumerate(state_list):
        streams[i].states = s
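The new assert checks that unstack_states returns exactly one state per active stream. For context, unstack_states splits the batched Emformer states into per-stream states. A minimal sketch of the idea, assuming the layout exercised by the new test below (attention caches with batch on dim 1, convolution caches with batch on dim 0); the real helper may drop the size-1 batch dim rather than keep it:

def unstack_states_sketch(states):
    # states = [attn_caches, conv_caches], where
    #   attn_caches[layer] is a list of tensors of shape (length, batch, dim)
    #   conv_caches[layer] is a tensor of shape (batch, dim, kernel - 1)
    attn_caches, conv_caches = states
    batch_size = conv_caches[0].size(0)
    state_list = []
    for i in range(batch_size):
        # Slice with i:i+1 to keep a size-1 batch dim, so the per-stream
        # states can later be concatenated back by stack_states.
        attn = [[t[:, i : i + 1] for t in layer] for layer in attn_caches]
        conv = [t[i : i + 1] for t in conv_caches]
        state_list.append([attn, conv])
    return state_list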
@@ -358,15 +359,10 @@ def decode_dataset(
    """
    device = next(model.parameters()).device

-    opts = FbankOptions()
-    opts.device = device
-    opts.frame_opts.dither = 0
-    opts.frame_opts.snip_edges = False
-    opts.frame_opts.samp_freq = 16000
-    opts.mel_opts.num_bins = 80

    log_interval = 300

+    fbank = create_streaming_feature_extractor()

    decode_results = []
    streams = []
    for num, cut in enumerate(cuts):
@@ -382,7 +378,6 @@ def decode_dataset(
        assert audio.max() <= 1, "Should be normalized to [-1, 1]"
        samples = torch.from_numpy(audio).squeeze(0)

-        fbank = create_streaming_feature_extractor()
        feature = fbank(samples)
        stream.set_feature(feature)
        stream.set_ground_truth(cut.supervisions[0].text)
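These two hunks fold the inline FbankOptions block into create_streaming_feature_extractor() and hoist the extractor out of the per-cut loop, so it is constructed once per dataset rather than once per utterance. A plausible sketch of such a helper, assuming the kaldifeat extractor implied by the removed options (the helper's real body is not shown in this diff):

from kaldifeat import Fbank, FbankOptions

def create_streaming_feature_extractor() -> Fbank:
    # Same settings as the removed inline block; the original set
    # opts.device to the model's device rather than hard-coding CPU.
    opts = FbankOptions()
    opts.device = "cpu"
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = 16000
    opts.mel_opts.num_bins = 80
    return Fbank(opts)

fbank(samples) then computes the 80-bin filter-bank features that are attached to each stream.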

View File

@@ -511,15 +511,75 @@ def test_emformer_infer():
        assert conv_cache.shape == (B, D, kernel_size - 1)
+def test_state_stack_unstack():
+    from emformer import Emformer, stack_states, unstack_states
+
+    num_features = 80
+    chunk_length = 32
+    encoder_dim = 512
+    num_encoder_layers = 2
+    kernel_size = 31
+    left_context_length = 32
+    right_context_length = 8
+    memory_size = 32
+    batch_size = 2
+
+    model = Emformer(
+        num_features=num_features,
+        chunk_length=chunk_length,
+        subsampling_factor=4,
+        d_model=encoder_dim,
+        num_encoder_layers=num_encoder_layers,
+        cnn_module_kernel=kernel_size,
+        left_context_length=left_context_length,
+        right_context_length=right_context_length,
+        memory_size=memory_size,
+    )
+
+    attn_caches = [
+        [
+            torch.zeros(memory_size, batch_size, encoder_dim),
+            torch.zeros(left_context_length // 4, batch_size, encoder_dim),
+            torch.zeros(
+                left_context_length // 4,
+                batch_size,
+                encoder_dim,
+            ),
+        ]
+        for _ in range(num_encoder_layers)
+    ]
+    conv_caches = [
+        torch.zeros(batch_size, encoder_dim, kernel_size - 1)
+        for _ in range(num_encoder_layers)
+    ]
+    states = [attn_caches, conv_caches]
+
+    x = torch.randn(batch_size, 23, num_features)
+    x_lens = torch.full((batch_size,), 23)
+    num_processed_frames = torch.full((batch_size,), 0)
+    y, y_lens, states = model.infer(
+        x, x_lens, num_processed_frames=num_processed_frames, states=states
+    )
+
+    state_list = unstack_states(states)
+    states2 = stack_states(state_list)
+
+    for ss, ss2 in zip(states[0], states2[0]):
+        for s, s2 in zip(ss, ss2):
+            assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}"
+    for s, s2 in zip(states[1], states2[1]):
+        assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}"
if __name__ == "__main__":
-    test_emformer_attention_forward()
-    test_emformer_attention_infer()
-    test_convolution_module_forward()
-    test_convolution_module_infer()
-    test_emformer_encoder_layer_forward()
-    test_emformer_encoder_layer_infer()
-    test_emformer_encoder_forward()
-    test_emformer_encoder_infer()
-    test_emformer_encoder_forward_infer_consistency()
-    test_emformer_forward()
-    test_emformer_infer()
+    # test_emformer_attention_forward()
+    # test_emformer_attention_infer()
+    # test_convolution_module_forward()
+    # test_convolution_module_infer()
+    # test_emformer_encoder_layer_forward()
+    # test_emformer_encoder_layer_infer()
+    # test_emformer_encoder_forward()
+    # test_emformer_encoder_infer()
+    # test_emformer_encoder_forward_infer_consistency()
+    # test_emformer_forward()
+    # test_emformer_infer()
+    test_state_stack_unstack()
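stack_states is the inverse of unstack_states: it concatenates per-stream states back along the batch dimensions (dim 1 for the attention caches, dim 0 for the convolution caches), and the new test asserts exactly this round trip after a real model.infer() call. A hedged sketch matching the unstack_states_sketch above:

import torch

def stack_states_sketch(state_list):
    # state_list[i] = [attn, conv] for stream i, with size-1 batch dims.
    num_layers = len(state_list[0][0])
    attn_caches = [
        [
            torch.cat([s[0][li][k] for s in state_list], dim=1)
            for k in range(len(state_list[0][0][li]))
        ]
        for li in range(num_layers)
    ]
    conv_caches = [
        torch.cat([s[1][li] for s in state_list], dim=0)
        for li in range(num_layers)
    ]
    return [attn_caches, conv_caches]

Under these sketches, stack_states_sketch(unstack_states_sketch(states)) reproduces states tensor for tensor, which is why the test's torch.allclose checks must pass.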