Mirror of https://github.com/k2-fsa/icefall.git
test states stack and unstack

commit 507d7c13f4 (parent 7f09720403)
@@ -557,7 +557,7 @@ class HypothesisList(object):
         return ", ".join(s)


-def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
+def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
     """Return a ragged shape with axes [utt][num_hyps].

     Args:
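Note: the hunk only renames the helper; its body lies outside the diff context. For reference, a minimal sketch of how such a helper can build the [utt][num_hyps] shape, assuming k2's create_ragged_shape2 API. The function name matches the diff; the body is illustrative, not necessarily the repository's code:

    import k2
    import torch

    def get_hyps_shape(hyps) -> k2.RaggedShape:
        # hyps is a list of sized hypothesis containers, one per utterance.
        # Row i of the shape has len(hyps[i]) elements: one per hypothesis.
        num_hyps = torch.tensor([0] + [len(h) for h in hyps], dtype=torch.int32)
        row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32)
        return k2.ragged.create_ragged_shape2(
            row_splits=row_splits, cached_tot_size=row_splits[-1].item()
        )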
@@ -648,7 +648,7 @@ def modified_beam_search(
         finalized_B = B[batch_size:] + finalized_B
         B = B[:batch_size]

-        hyps_shape = _get_hyps_shape(B).to(device)
+        hyps_shape = get_hyps_shape(B).to(device)

         A = [list(b) for b in B]
         B = [HypothesisList() for _ in range(batch_size)]
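Note: a typical use of this shape in a modified beam search is to wrap flattened per-hypothesis scores so that per-utterance reductions respect row boundaries. A hedged, self-contained illustration (the scores and row sizes below are made up for the example):

    import k2
    import torch

    # Hypothetical scores: 2 utterances with 3 and 2 hypotheses respectively.
    row_splits = torch.tensor([0, 3, 5], dtype=torch.int32)
    hyps_shape = k2.ragged.create_ragged_shape2(
        row_splits=row_splits, cached_tot_size=5
    )
    log_probs = torch.tensor([-1.0, -0.5, -2.0, -0.1, -3.0])

    # Wrapping the flat scores with the ragged shape lets argmax run per row,
    # i.e. it picks the best hypothesis index within each utterance.
    ragged_log_probs = k2.RaggedTensor(shape=hyps_shape, value=log_probs)
    best = ragged_log_probs.argmax()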
@@ -302,6 +302,7 @@ def decode_one_chunk(

     # update cached states of each stream
     state_list = unstack_states(states)
+    assert len(streams) == len(state_list)
     for i, s in enumerate(state_list):
         streams[i].states = s

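Note: the new assert pins down the contract that unstack_states yields exactly one per-stream state per active stream. A minimal sketch of what such a stack/unstack pair can look like for the state layout exercised by the new test below (states = [attn_caches, conv_caches]). This is an illustration under the assumption that per-stream states keep a batch dimension of size 1; it is not the repository's implementation:

    import torch

    def unstack_states_sketch(states):
        # states = [attn_caches, conv_caches]:
        #   attn_caches[layer] is a list of tensors shaped (length, batch, dim);
        #   conv_caches[layer] is a tensor shaped (batch, dim, kernel_size - 1).
        attn_caches, conv_caches = states
        batch_size = conv_caches[0].size(0)
        ans = []
        for i in range(batch_size):
            # Slice out stream i, keeping a batch dimension of size 1.
            attn = [[t[:, i : i + 1] for t in layer] for layer in attn_caches]
            conv = [t[i : i + 1] for t in conv_caches]
            ans.append([attn, conv])
        return ans

    def stack_states_sketch(state_list):
        # Inverse: concatenate per-stream slices along each batch dimension.
        attn_caches = [
            [torch.cat(tensors, dim=1) for tensors in zip(*layers)]
            for layers in zip(*(s[0] for s in state_list))
        ]
        conv_caches = [
            torch.cat(tensors, dim=0)
            for tensors in zip(*(s[1] for s in state_list))
        ]
        return [attn_caches, conv_caches]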
@@ -358,15 +359,10 @@ def decode_dataset(
     """
     device = next(model.parameters()).device

-    opts = FbankOptions()
-    opts.device = device
-    opts.frame_opts.dither = 0
-    opts.frame_opts.snip_edges = False
-    opts.frame_opts.samp_freq = 16000
-    opts.mel_opts.num_bins = 80
-
     log_interval = 300

+    fbank = create_streaming_feature_extractor()
+
     decode_results = []
     streams = []
     for num, cut in enumerate(cuts):
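Note: the six inline option lines move behind a single factory call. A sketch of what such a factory can look like, assuming kaldifeat, whose FbankOptions fields match the removed lines exactly; the function name follows the call site, and binding opts.device to the model's device (as the removed code did) is dropped here for simplicity:

    import torch
    import kaldifeat

    def create_streaming_feature_extractor() -> kaldifeat.Fbank:
        # Same options the removed lines set inline in decode_dataset.
        opts = kaldifeat.FbankOptions()
        opts.device = torch.device("cpu")
        opts.frame_opts.dither = 0
        opts.frame_opts.snip_edges = False
        opts.frame_opts.samp_freq = 16000
        opts.mel_opts.num_bins = 80
        return kaldifeat.Fbank(opts)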
@@ -382,7 +378,6 @@ def decode_dataset(
         assert audio.max() <= 1, "Should be normalized to [-1, 1])"

         samples = torch.from_numpy(audio).squeeze(0)
-        fbank = create_streaming_feature_extractor()
         feature = fbank(samples)
         stream.set_feature(feature)
         stream.set_ground_truth(cut.supervisions[0].text)
@@ -511,15 +511,75 @@ def test_emformer_infer():
     assert conv_cache.shape == (B, D, kernel_size - 1)


+def test_state_stack_unstack():
+    from emformer import Emformer, stack_states, unstack_states
+
+    num_features = 80
+    chunk_length = 32
+    encoder_dim = 512
+    num_encoder_layers = 2
+    kernel_size = 31
+    left_context_length = 32
+    right_context_length = 8
+    memory_size = 32
+    batch_size = 2
+
+    model = Emformer(
+        num_features=num_features,
+        chunk_length=chunk_length,
+        subsampling_factor=4,
+        d_model=encoder_dim,
+        num_encoder_layers=num_encoder_layers,
+        cnn_module_kernel=kernel_size,
+        left_context_length=left_context_length,
+        right_context_length=right_context_length,
+        memory_size=memory_size,
+    )
+    attn_caches = [
+        [
+            torch.zeros(memory_size, batch_size, encoder_dim),
+            torch.zeros(left_context_length // 4, batch_size, encoder_dim),
+            torch.zeros(
+                left_context_length // 4,
+                batch_size,
+                encoder_dim,
+            ),
+        ]
+        for _ in range(num_encoder_layers)
+    ]
+    conv_caches = [
+        torch.zeros(batch_size, encoder_dim, kernel_size - 1)
+        for _ in range(num_encoder_layers)
+    ]
+    states = [attn_caches, conv_caches]
+    x = torch.randn(batch_size, 23, num_features)
+    x_lens = torch.full((batch_size,), 23)
+    num_processed_frames = torch.full((batch_size,), 0)
+    y, y_lens, states = model.infer(
+        x, x_lens, num_processed_frames=num_processed_frames, states=states
+    )
+
+    state_list = unstack_states(states)
+    states2 = stack_states(state_list)
+
+    for ss, ss2 in zip(states[0], states2[0]):
+        for s, s2 in zip(ss, ss2):
+            assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}"
+
+    for s, s2 in zip(states[1], states2[1]):
+        assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}"
+
+
 if __name__ == "__main__":
-    test_emformer_attention_forward()
-    test_emformer_attention_infer()
-    test_convolution_module_forward()
-    test_convolution_module_infer()
-    test_emformer_encoder_layer_forward()
-    test_emformer_encoder_layer_infer()
-    test_emformer_encoder_forward()
-    test_emformer_encoder_infer()
-    test_emformer_encoder_forward_infer_consistency()
-    test_emformer_forward()
-    test_emformer_infer()
+    # test_emformer_attention_forward()
+    # test_emformer_attention_infer()
+    # test_convolution_module_forward()
+    # test_convolution_module_infer()
+    # test_emformer_encoder_layer_forward()
+    # test_emformer_encoder_layer_infer()
+    # test_emformer_encoder_forward()
+    # test_emformer_encoder_infer()
+    # test_emformer_encoder_forward_infer_consistency()
+    # test_emformer_forward()
+    # test_emformer_infer()
+    test_state_stack_unstack()
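Note: the invariant the new test checks, that stack_states(unstack_states(states)) reproduces states exactly, can also be exercised standalone with random cache contents. A minimal sketch reusing the hypothetical stack/unstack helpers sketched after the decode_one_chunk hunk above:

    import torch

    # Random cache contents with the same nesting as the test above.
    num_layers, batch, dim, mem, ctx, kernel = 2, 3, 8, 4, 8, 31
    attn = [
        [
            torch.randn(mem, batch, dim),
            torch.randn(ctx, batch, dim),
            torch.randn(ctx, batch, dim),
        ]
        for _ in range(num_layers)
    ]
    conv = [torch.randn(batch, dim, kernel - 1) for _ in range(num_layers)]
    states = [attn, conv]

    states2 = stack_states_sketch(unstack_states_sketch(states))

    # The round trip must be exact: slicing and concatenating along the
    # batch dimension neither reorders nor rescales any cache entries.
    for layer, layer2 in zip(states[0], states2[0]):
        for s, s2 in zip(layer, layer2):
            assert torch.equal(s, s2)
    for s, s2 in zip(states[1], states2[1]):
        assert torch.equal(s, s2)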