mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 22:54:18 +00:00
Add test functions for torchaudio emformer codes.
This commit is contained in:
parent
524f3aa015
commit
32420cc3e4
@ -65,8 +65,135 @@ def test_emformer():
|
||||
print(f"Number of encoder parameters: {num_param}")
|
||||
|
||||
|
||||
def test_emformer_infer_batch_single_consistency():
|
||||
"""Test consistency of cached states and output logits between single
|
||||
utterance inference and batch inference."""
|
||||
from emformer import Emformer
|
||||
|
||||
num_features = 80
|
||||
output_dim = 1000
|
||||
chunk_length = 8
|
||||
num_chunks = 3
|
||||
U = num_chunks * chunk_length
|
||||
L, R = 128, 4
|
||||
B, D = 2, 256
|
||||
num_encoder_layers = 4
|
||||
for use_memory in [True, False]:
|
||||
if use_memory:
|
||||
M = 3
|
||||
else:
|
||||
M = 0
|
||||
model = Emformer(
|
||||
num_features=num_features,
|
||||
output_dim=output_dim,
|
||||
segment_length=chunk_length,
|
||||
subsampling_factor=4,
|
||||
d_model=D,
|
||||
nhead=4,
|
||||
dim_feedforward=1024,
|
||||
num_encoder_layers=num_encoder_layers,
|
||||
left_context_length=L,
|
||||
right_context_length=R,
|
||||
max_memory_size=M,
|
||||
vgg_frontend=False,
|
||||
)
|
||||
model.eval()
|
||||
|
||||
def save_states(states):
|
||||
saved_states = []
|
||||
for layer_idx in range(len(states)):
|
||||
layer_state = []
|
||||
layer_state.append(states[layer_idx][0].clone()) # memory
|
||||
layer_state.append(
|
||||
states[layer_idx][1].clone()
|
||||
) # left_context_key
|
||||
layer_state.append(
|
||||
states[layer_idx][2].clone()
|
||||
) # left_context_val
|
||||
layer_state.append(states[layer_idx][3].clone()) # past_length
|
||||
saved_states.append(layer_state)
|
||||
return saved_states
|
||||
|
||||
def assert_states_equal(saved_states, states, sample_idx):
|
||||
for layer_idx in range(len(saved_states)):
|
||||
# assert eqaul memory
|
||||
assert torch.allclose(
|
||||
states[layer_idx][0],
|
||||
saved_states[layer_idx][0][
|
||||
:, sample_idx : sample_idx + 1 # noqa
|
||||
],
|
||||
atol=1e-5,
|
||||
rtol=0.0,
|
||||
)
|
||||
# assert equal left_context_key
|
||||
assert torch.allclose(
|
||||
states[layer_idx][1],
|
||||
saved_states[layer_idx][1][
|
||||
:, sample_idx : sample_idx + 1 # noqa
|
||||
],
|
||||
atol=1e-5,
|
||||
rtol=0.0,
|
||||
)
|
||||
# assert equal left_context_val
|
||||
assert torch.allclose(
|
||||
states[layer_idx][2],
|
||||
saved_states[layer_idx][2][
|
||||
:, sample_idx : sample_idx + 1 # noqa
|
||||
],
|
||||
atol=1e-5,
|
||||
rtol=0.0,
|
||||
)
|
||||
# assert eqaul past_length
|
||||
assert torch.equal(
|
||||
states[layer_idx][3],
|
||||
saved_states[layer_idx][3][
|
||||
:, sample_idx : sample_idx + 1 # noqa
|
||||
],
|
||||
)
|
||||
|
||||
x = torch.randn(B, U + R + 3, num_features)
|
||||
batch_logits = []
|
||||
batch_states = []
|
||||
states = None
|
||||
for chunk_idx in range(num_chunks):
|
||||
start_idx = chunk_idx * chunk_length
|
||||
end_idx = start_idx + chunk_length
|
||||
chunk = x[:, start_idx : end_idx + R + 3] # noqa
|
||||
lengths = torch.tensor([chunk_length + R + 3]).expand(B)
|
||||
logits, output_lengths, states = model.streaming_forward(
|
||||
chunk, lengths, states
|
||||
)
|
||||
batch_logits.append(logits)
|
||||
batch_states.append(save_states(states))
|
||||
batch_logits = torch.cat(batch_logits, dim=1)
|
||||
|
||||
single_logits = []
|
||||
for sample_idx in range(B):
|
||||
sample = x[sample_idx : sample_idx + 1] # noqa
|
||||
chunk_logits = []
|
||||
states = None
|
||||
for chunk_idx in range(num_chunks):
|
||||
start_idx = chunk_idx * chunk_length
|
||||
end_idx = start_idx + chunk_length
|
||||
chunk = sample[:, start_idx : end_idx + R + 3] # noqa
|
||||
lengths = torch.tensor([chunk_length + R + 3])
|
||||
logits, output_lengths, states = model.streaming_forward(
|
||||
chunk, lengths, states
|
||||
)
|
||||
chunk_logits.append(logits)
|
||||
|
||||
assert_states_equal(batch_states[chunk_idx], states, sample_idx)
|
||||
|
||||
chunk_logits = torch.cat(chunk_logits, dim=1)
|
||||
single_logits.append(chunk_logits)
|
||||
single_logits = torch.cat(single_logits, dim=0)
|
||||
|
||||
assert torch.allclose(batch_logits, single_logits, atol=1e-5, rtol=0.0)
|
||||
|
||||
|
||||
def main():
|
||||
test_emformer()
|
||||
test_emformer_infer_batch_single_consistency()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user