update test_emformer.py

This commit is contained in:
yaozengwei 2022-04-08 20:43:54 +08:00
parent d58002c414
commit 3e131891a2

View File

@ -61,16 +61,20 @@ def test_emformer_attention_infer():
left_context_key = torch.randn(L, B, D)
left_context_val = torch.randn(L, B, D)
output_right_context_utterance, output_memory, next_key, next_val = \
attention.infer(
utterance,
lengths,
right_context,
summary,
memory,
left_context_key,
left_context_val,
)
(
output_right_context_utterance,
output_memory,
next_key,
next_val,
) = attention.infer(
utterance,
lengths,
right_context,
summary,
memory,
left_context_key,
left_context_val,
)
assert output_right_context_utterance.shape == (R + U, B, D)
assert output_memory.shape == (S, B, D)
assert next_key.shape == (L + U, B, D)
@ -98,7 +102,7 @@ def test_emformer_layer_forward():
chunk_length=chunk_length,
left_context_length=L,
max_memory_size=M,
)
)
Q, KV = R + U + S, M + R + U
utterance = torch.randn(U, B, D)
@ -141,7 +145,7 @@ def test_emformer_layer_infer():
chunk_length=chunk_length,
left_context_length=L,
max_memory_size=M,
)
)
utterance = torch.randn(U, B, D)
lengths = torch.randint(1, U + 1, (B,))
@ -149,14 +153,18 @@ def test_emformer_layer_infer():
right_context = torch.randn(R, B, D)
memory = torch.randn(M, B, D)
state = None
output_utterance, output_right_context, output_memory, output_state = \
layer.infer(
utterance,
lengths,
right_context,
memory,
state,
)
(
output_utterance,
output_right_context,
output_memory,
output_state,
) = layer.infer(
utterance,
lengths,
right_context,
memory,
state,
)
assert output_utterance.shape == (U, B, D)
assert output_right_context.shape == (R, B, D)
if use_memory:
@ -200,9 +208,7 @@ def test_emformer_encoder_forward():
output, output_lengths = encoder(x, lengths)
assert output.shape == (U, B, D)
assert torch.equal(
output_lengths, torch.clamp(lengths - R, min=0)
)
assert torch.equal(output_lengths, torch.clamp(lengths - R, min=0))
def test_emformer_encoder_infer():
@ -236,8 +242,7 @@ def test_emformer_encoder_infer():
x = torch.randn(U + R, B, D)
lengths = torch.randint(1, U + R + 1, (B,))
lengths[0] = U + R
output, output_lengths, states = \
encoder.infer(x, lengths, states)
output, output_lengths, states = encoder.infer(x, lengths, states)
assert output.shape == (U, B, D)
assert torch.equal(output_lengths, torch.clamp(lengths - R, min=0))
assert len(states) == num_encoder_layers
@ -253,6 +258,7 @@ def test_emformer_encoder_infer():
def test_emformer_forward():
from emformer import Emformer
num_features = 80
output_dim = 1000
chunk_length = 8
@ -281,12 +287,13 @@ def test_emformer_forward():
assert logits.shape == (B, U // 4, output_dim)
assert torch.equal(
output_lengths,
torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0)
torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0),
)
def test_emformer_infer():
from emformer import Emformer
num_features = 80
output_dim = 1000
chunk_length = 8
@ -317,12 +324,11 @@ def test_emformer_infer():
x = torch.randn(B, U + R + 3, num_features)
x_lens = torch.randint(1, U + R + 3 + 1, (B,))
x_lens[0] = U + R + 3
logits, output_lengths, states = \
model.infer(x, x_lens, states)
logits, output_lengths, states = model.infer(x, x_lens, states)
assert logits.shape == (B, U // 4, output_dim)
assert torch.equal(
output_lengths,
torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0)
torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0),
)
assert len(states) == num_encoder_layers
for state in states:
@ -332,7 +338,7 @@ def test_emformer_infer():
assert state[2].shape == (L // 4, B, D)
assert torch.equal(
state[3],
U // 4 * (chunk_idx + 1) * torch.ones_like(state[3])
U // 4 * (chunk_idx + 1) * torch.ones_like(state[3]),
)