test states stack and unstack
This commit is contained in:
parent
7f09720403
commit
507d7c13f4
@ -557,7 +557,7 @@ class HypothesisList(object):
|
|||||||
return ", ".join(s)
|
return ", ".join(s)
|
||||||
|
|
||||||
|
|
||||||
def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
|
def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
|
||||||
"""Return a ragged shape with axes [utt][num_hyps].
|
"""Return a ragged shape with axes [utt][num_hyps].
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -648,7 +648,7 @@ def modified_beam_search(
|
|||||||
finalized_B = B[batch_size:] + finalized_B
|
finalized_B = B[batch_size:] + finalized_B
|
||||||
B = B[:batch_size]
|
B = B[:batch_size]
|
||||||
|
|
||||||
hyps_shape = _get_hyps_shape(B).to(device)
|
hyps_shape = get_hyps_shape(B).to(device)
|
||||||
|
|
||||||
A = [list(b) for b in B]
|
A = [list(b) for b in B]
|
||||||
B = [HypothesisList() for _ in range(batch_size)]
|
B = [HypothesisList() for _ in range(batch_size)]
|
||||||
|
|||||||
@ -302,6 +302,7 @@ def decode_one_chunk(
|
|||||||
|
|
||||||
# update cached states of each stream
|
# update cached states of each stream
|
||||||
state_list = unstack_states(states)
|
state_list = unstack_states(states)
|
||||||
|
assert len(streams) == len(state_list)
|
||||||
for i, s in enumerate(state_list):
|
for i, s in enumerate(state_list):
|
||||||
streams[i].states = s
|
streams[i].states = s
|
||||||
|
|
||||||
@ -358,15 +359,10 @@ def decode_dataset(
|
|||||||
"""
|
"""
|
||||||
device = next(model.parameters()).device
|
device = next(model.parameters()).device
|
||||||
|
|
||||||
opts = FbankOptions()
|
|
||||||
opts.device = device
|
|
||||||
opts.frame_opts.dither = 0
|
|
||||||
opts.frame_opts.snip_edges = False
|
|
||||||
opts.frame_opts.samp_freq = 16000
|
|
||||||
opts.mel_opts.num_bins = 80
|
|
||||||
|
|
||||||
log_interval = 300
|
log_interval = 300
|
||||||
|
|
||||||
|
fbank = create_streaming_feature_extractor()
|
||||||
|
|
||||||
decode_results = []
|
decode_results = []
|
||||||
streams = []
|
streams = []
|
||||||
for num, cut in enumerate(cuts):
|
for num, cut in enumerate(cuts):
|
||||||
@ -382,7 +378,6 @@ def decode_dataset(
|
|||||||
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
|
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
|
||||||
|
|
||||||
samples = torch.from_numpy(audio).squeeze(0)
|
samples = torch.from_numpy(audio).squeeze(0)
|
||||||
fbank = create_streaming_feature_extractor()
|
|
||||||
feature = fbank(samples)
|
feature = fbank(samples)
|
||||||
stream.set_feature(feature)
|
stream.set_feature(feature)
|
||||||
stream.set_ground_truth(cut.supervisions[0].text)
|
stream.set_ground_truth(cut.supervisions[0].text)
|
||||||
|
|||||||
@ -511,15 +511,75 @@ def test_emformer_infer():
|
|||||||
assert conv_cache.shape == (B, D, kernel_size - 1)
|
assert conv_cache.shape == (B, D, kernel_size - 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_state_stack_unstack():
    """Check that ``stack_states(unstack_states(states))`` reproduces ``states``.

    An ``Emformer`` is built with small hyper-parameters and run once through
    ``infer`` on a random batch, starting from all-zero caches.  The states
    returned by ``infer`` are unstacked, stacked again, and compared with the
    originals tensor by tensor.

    The states are indexed as ``[attn_caches, conv_caches]``: ``states[0]`` is
    a list (one entry per layer) of lists of tensors, and ``states[1]`` is a
    list of tensors (one per layer).
    """
    from emformer import Emformer, stack_states, unstack_states

    num_features = 80
    chunk_length = 32
    encoder_dim = 512
    num_encoder_layers = 2
    kernel_size = 31
    left_context_length = 32
    right_context_length = 8
    memory_size = 32
    batch_size = 2

    model = Emformer(
        num_features=num_features,
        chunk_length=chunk_length,
        subsampling_factor=4,
        d_model=encoder_dim,
        num_encoder_layers=num_encoder_layers,
        cnn_module_kernel=kernel_size,
        left_context_length=left_context_length,
        right_context_length=right_context_length,
        memory_size=memory_size,
    )

    # Three zero tensors per layer.
    # NOTE(review): presumably memory, left-context key and left-context
    # value, with "// 4" matching subsampling_factor=4 -- confirm against
    # Emformer.infer.
    attn_caches = [
        [
            torch.zeros(memory_size, batch_size, encoder_dim),
            torch.zeros(left_context_length // 4, batch_size, encoder_dim),
            torch.zeros(
                left_context_length // 4,
                batch_size,
                encoder_dim,
            ),
        ]
        for _ in range(num_encoder_layers)
    ]
    # One zero tensor per layer for the convolution module.
    conv_caches = [
        torch.zeros(batch_size, encoder_dim, kernel_size - 1)
        for _ in range(num_encoder_layers)
    ]
    states = [attn_caches, conv_caches]

    x = torch.randn(batch_size, 23, num_features)
    x_lens = torch.full((batch_size,), 23)
    num_processed_frames = torch.full((batch_size,), 0)
    # Only the returned states are needed; the outputs are not inspected.
    y, y_lens, states = model.infer(
        x, x_lens, num_processed_frames=num_processed_frames, states=states
    )

    state_list = unstack_states(states)
    # One state per utterance in the batch.
    assert len(state_list) == batch_size, f"{len(state_list)}, {batch_size}"
    states2 = stack_states(state_list)

    # zip() silently stops at the shorter input, so every level of nesting is
    # length-checked first; otherwise a dropped layer or tensor would go
    # unnoticed.
    assert len(states[0]) == len(states2[0]), (
        f"{len(states[0])}, {len(states2[0])}"
    )
    for ss, ss2 in zip(states[0], states2[0]):
        assert len(ss) == len(ss2), f"{len(ss)}, {len(ss2)}"
        for s, s2 in zip(ss, ss2):
            # torch.allclose broadcasts, so compare the shapes explicitly.
            assert s.shape == s2.shape, f"{s.shape}, {s2.shape}"
            assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}"

    assert len(states[1]) == len(states2[1]), (
        f"{len(states[1])}, {len(states2[1])}"
    )
    for s, s2 in zip(states[1], states2[1]):
        assert s.shape == s2.shape, f"{s.shape}, {s2.shape}"
        assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Run every test in this file when it is executed directly, so that a
    # plain "python <this file>" run covers the whole suite rather than only
    # the most recently added test.
    test_emformer_attention_forward()
    test_emformer_attention_infer()
    test_convolution_module_forward()
    test_convolution_module_infer()
    test_emformer_encoder_layer_forward()
    test_emformer_encoder_layer_infer()
    test_emformer_encoder_forward()
    test_emformer_encoder_infer()
    test_emformer_encoder_forward_infer_consistency()
    test_emformer_forward()
    test_emformer_infer()
    test_state_stack_unstack()
|
||||||
|
|||||||
Reference in New Issue
Block a user