update test_emformer.py

commit 3e131891a2 (parent d58002c414)
mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-04 06:34:20 +00:00)
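The diff below reformats test_emformer.py to black-style layout: tuple assignments written with backslash continuations become parenthesized target lists, asserts that fit on one line are collapsed, trailing commas are added to wrapped call arguments, and a blank line is added after each function-local import.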
test_emformer.py
@@ -61,16 +61,20 @@ def test_emformer_attention_infer():
         left_context_key = torch.randn(L, B, D)
         left_context_val = torch.randn(L, B, D)

-        output_right_context_utterance, output_memory, next_key, next_val = \
-            attention.infer(
-                utterance,
-                lengths,
-                right_context,
-                summary,
-                memory,
-                left_context_key,
-                left_context_val,
-            )
+        (
+            output_right_context_utterance,
+            output_memory,
+            next_key,
+            next_val,
+        ) = attention.infer(
+            utterance,
+            lengths,
+            right_context,
+            summary,
+            memory,
+            left_context_key,
+            left_context_val,
+        )
         assert output_right_context_utterance.shape == (R + U, B, D)
         assert output_memory.shape == (S, B, D)
         assert next_key.shape == (L + U, B, D)
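The pattern applied in this and the later infer hunks: PEP 8 discourages backslash continuations, and black instead wraps multi-name assignment targets in parentheses. A minimal self-contained illustration (infer_stub is a hypothetical stand-in for attention.infer):

def infer_stub():
    # stand-in returning four values, like attention.infer(...)
    return 1, 2, 3, 4

# old style: backslash continuation
a, b, c, d = \
    infer_stub()

# new style: parenthesized target list, as black formats it
(
    a,
    b,
    c,
    d,
) = infer_stub()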
@@ -98,7 +102,7 @@ def test_emformer_layer_forward():
             chunk_length=chunk_length,
             left_context_length=L,
             max_memory_size=M,
-            )
+        )

         Q, KV = R + U + S, M + R + U
         utterance = torch.randn(U, B, D)
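A note on the unchanged Q, KV = R + U + S, M + R + U context line: in the Emformer attention these read as the query and key/value lengths, where queries span the right context, utterance, and summary (R + U + S positions) and keys/values span the memory, right context, and utterance (M + R + U positions).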
@@ -141,7 +145,7 @@ def test_emformer_layer_infer():
             chunk_length=chunk_length,
             left_context_length=L,
             max_memory_size=M,
-            )
+        )

         utterance = torch.randn(U, B, D)
         lengths = torch.randint(1, U + 1, (B,))
@@ -149,14 +153,18 @@ def test_emformer_layer_infer():
         right_context = torch.randn(R, B, D)
         memory = torch.randn(M, B, D)
         state = None
-        output_utterance, output_right_context, output_memory, output_state = \
-            layer.infer(
-                utterance,
-                lengths,
-                right_context,
-                memory,
-                state,
-            )
+        (
+            output_utterance,
+            output_right_context,
+            output_memory,
+            output_state,
+        ) = layer.infer(
+            utterance,
+            lengths,
+            right_context,
+            memory,
+            state,
+        )
         assert output_utterance.shape == (U, B, D)
         assert output_right_context.shape == (R, B, D)
         if use_memory:
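The infer path above is stateful: state is None before the first chunk, and the returned state is fed back on the next call. A minimal sketch of that loop, assuming the test's emformer module is importable and using illustrative sizes; the constructor arguments other than chunk_length, left_context_length, and max_memory_size are assumptions not shown in this diff:

import torch

from emformer import EmformerLayer

B, D, R, L, M = 2, 256, 2, 3, 3
chunk_length = 4
U = chunk_length

layer = EmformerLayer(
    d_model=D,  # assumed parameter name
    nhead=8,  # assumed
    dim_feedforward=1024,  # assumed
    chunk_length=chunk_length,
    left_context_length=L,
    max_memory_size=M,
)

state = None  # no cached left context before the first chunk
for _ in range(3):
    utterance = torch.randn(U, B, D)
    lengths = torch.full((B,), U)
    right_context = torch.randn(R, B, D)
    memory = torch.randn(M, B, D)
    (
        output_utterance,
        output_right_context,
        output_memory,
        state,  # thread the updated state into the next call
    ) = layer.infer(utterance, lengths, right_context, memory, state)
    assert output_utterance.shape == (U, B, D)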
@@ -200,9 +208,7 @@ def test_emformer_encoder_forward():

         output, output_lengths = encoder(x, lengths)
         assert output.shape == (U, B, D)
-        assert torch.equal(
-            output_lengths, torch.clamp(lengths - R, min=0)
-        )
+        assert torch.equal(output_lengths, torch.clamp(lengths - R, min=0))


 def test_emformer_encoder_infer():
@@ -236,8 +242,7 @@ def test_emformer_encoder_infer():
         x = torch.randn(U + R, B, D)
         lengths = torch.randint(1, U + R + 1, (B,))
         lengths[0] = U + R
-        output, output_lengths, states = \
-            encoder.infer(x, lengths, states)
+        output, output_lengths, states = encoder.infer(x, lengths, states)
         assert output.shape == (U, B, D)
         assert torch.equal(output_lengths, torch.clamp(lengths - R, min=0))
         assert len(states) == num_encoder_layers
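The collapsed assert checks the length bookkeeping: the encoder consumes R look-ahead (right-context) frames, so each output is R frames shorter than its input, floored at zero. A quick standalone check with an illustrative R:

import torch

R = 2  # illustrative; the test defines its own R
lengths = torch.tensor([10, 1])
print(torch.clamp(lengths - R, min=0))  # tensor([8, 0])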
@@ -253,6 +258,7 @@ def test_emformer_encoder_infer():

 def test_emformer_forward():
     from emformer import Emformer
+
     num_features = 80
     output_dim = 1000
     chunk_length = 8
@@ -281,12 +287,13 @@ def test_emformer_forward():
     assert logits.shape == (B, U // 4, output_dim)
     assert torch.equal(
         output_lengths,
-        torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0)
+        torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0),
     )


 def test_emformer_infer():
     from emformer import Emformer
+
     num_features = 80
     output_dim = 1000
     chunk_length = 8
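The expected-length formula here is consistent with two stride-2 subsampling stages, each mapping a length n to (n - 1) // 2 (as a kernel-3, stride-2 convolution without padding does), followed by removal of the subsampled right context R // 4. A standalone check with illustrative numbers:

import torch

R = 4  # illustrative; the test's R is defined outside this hunk
x_lens = torch.tensor([35])
subsampled = ((x_lens - 1) // 2 - 1) // 2  # 35 -> 17 -> 8
print(torch.clamp(subsampled - R // 4, min=0))  # tensor([7])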
@@ -317,12 +324,11 @@ def test_emformer_infer():
         x = torch.randn(B, U + R + 3, num_features)
         x_lens = torch.randint(1, U + R + 3 + 1, (B,))
         x_lens[0] = U + R + 3
-        logits, output_lengths, states = \
-            model.infer(x, x_lens, states)
+        logits, output_lengths, states = model.infer(x, x_lens, states)
         assert logits.shape == (B, U // 4, output_dim)
         assert torch.equal(
             output_lengths,
-            torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0)
+            torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0),
         )
         assert len(states) == num_encoder_layers
         for state in states:
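Note the layout difference one level up: model.infer takes batch-first acoustic frames of shape (B, T, num_features), unlike the (T, B, D) tensors in the layer and encoder tests, while states follows the same threading pattern, presumably starting from None as in the layer test.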
@@ -332,7 +338,7 @@ def test_emformer_infer():
             assert state[2].shape == (L // 4, B, D)
             assert torch.equal(
                 state[3],
-                U // 4 * (chunk_idx + 1) * torch.ones_like(state[3])
+                U // 4 * (chunk_idx + 1) * torch.ones_like(state[3]),
             )

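The final assert reads as a per-stream frame counter: each chunk contributes U input frames, i.e. U // 4 frames after the 4x subsampling, so after chunk_idx + 1 chunks every entry of state[3] should equal U // 4 * (chunk_idx + 1). A standalone sketch of that arithmetic (sizes are illustrative):

import torch

U = 8  # illustrative chunk size in input frames
counter = torch.zeros(2, dtype=torch.int64)  # stands in for state[3]
for chunk_idx in range(3):
    counter = counter + U // 4
    assert torch.equal(
        counter, U // 4 * (chunk_idx + 1) * torch.ones_like(counter)
    )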