update test_emformer.py

yaozengwei 2022-04-08 20:43:54 +08:00
parent d58002c414
commit 3e131891a2

test_emformer.py
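All of the changes below are formatting-only, in black style: backslash line continuations are replaced by parenthesized tuple unpacking, closing brackets are dedented to the opening statement's indent, trailing commas are added to multi-line argument lists, and (judging by the hunk line counts) a blank line is added after the in-function imports. No test logic changes. A minimal sketch of the unpacking rewrite (hypothetical names, not from this file):

def some_call(x):
    return x, x + 1, x + 2

# Before: a backslash continuation carries the multi-target assignment.
a, b, c = \
    some_call(1)

# After: the targets are wrapped in parentheses instead, one per line
# with trailing commas, so no continuation character is needed.
(
    a,
    b,
    c,
) = some_call(1)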

@@ -61,16 +61,20 @@ def test_emformer_attention_infer():
         left_context_key = torch.randn(L, B, D)
         left_context_val = torch.randn(L, B, D)
-        output_right_context_utterance, output_memory, next_key, next_val = \
-            attention.infer(
-                utterance,
-                lengths,
-                right_context,
-                summary,
-                memory,
-                left_context_key,
-                left_context_val,
-            )
+        (
+            output_right_context_utterance,
+            output_memory,
+            next_key,
+            next_val,
+        ) = attention.infer(
+            utterance,
+            lengths,
+            right_context,
+            summary,
+            memory,
+            left_context_key,
+            left_context_val,
+        )
         assert output_right_context_utterance.shape == (R + U, B, D)
         assert output_memory.shape == (S, B, D)
         assert next_key.shape == (L + U, B, D)
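A reading aid for the shape assertions throughout this file (inferred from the surrounding test setup, so treat the exact meanings as assumptions): B is the batch size, D the attention/model dimension, U the utterance (chunk) length, R the right-context length, L the left-context length, S the number of summary vectors, and M the memory size. For instance, next_key growing to L + U is consistent with the cached left-context keys being concatenated with the keys of the new utterance:

import torch

# Illustrative values only; the real test draws them from its own setup.
L, U, B, D = 8, 12, 2, 256
left_context_key = torch.randn(L, B, D)
utterance_key = torch.randn(U, B, D)  # hypothetical per-utterance keys
next_key = torch.cat([left_context_key, utterance_key], dim=0)
assert next_key.shape == (L + U, B, D)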
@@ -98,7 +102,7 @@ def test_emformer_layer_forward():
             chunk_length=chunk_length,
             left_context_length=L,
             max_memory_size=M,
-            )
+        )

         Q, KV = R + U + S, M + R + U
         utterance = torch.randn(U, B, D)
@@ -141,7 +145,7 @@ def test_emformer_layer_infer():
             chunk_length=chunk_length,
             left_context_length=L,
             max_memory_size=M,
-            )
+        )

         utterance = torch.randn(U, B, D)
         lengths = torch.randint(1, U + 1, (B,))
@@ -149,14 +153,18 @@ def test_emformer_layer_infer():
         right_context = torch.randn(R, B, D)
         memory = torch.randn(M, B, D)
         state = None
-        output_utterance, output_right_context, output_memory, output_state = \
-            layer.infer(
-                utterance,
-                lengths,
-                right_context,
-                memory,
-                state,
-            )
+        (
+            output_utterance,
+            output_right_context,
+            output_memory,
+            output_state,
+        ) = layer.infer(
+            utterance,
+            lengths,
+            right_context,
+            memory,
+            state,
+        )
         assert output_utterance.shape == (U, B, D)
         assert output_right_context.shape == (R, B, D)
         if use_memory:
@@ -200,9 +208,7 @@ def test_emformer_encoder_forward():
         output, output_lengths = encoder(x, lengths)
         assert output.shape == (U, B, D)
-        assert torch.equal(
-            output_lengths, torch.clamp(lengths - R, min=0)
-        )
+        assert torch.equal(output_lengths, torch.clamp(lengths - R, min=0))


 def test_emformer_encoder_infer():
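In the hunk above, the length assertion encodes that the encoder consumes R right-context frames from each sequence, so every output is R frames shorter than its input, floored at zero for short sequences. A quick self-contained check with arbitrary values:

import torch

R = 4                               # right-context length (illustrative)
lengths = torch.tensor([10, 4, 2])  # arbitrary input lengths
output_lengths = torch.clamp(lengths - R, min=0)
assert output_lengths.tolist() == [6, 0, 0]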
@@ -236,8 +242,7 @@ def test_emformer_encoder_infer():
             x = torch.randn(U + R, B, D)
             lengths = torch.randint(1, U + R + 1, (B,))
             lengths[0] = U + R
-            output, output_lengths, states = \
-                encoder.infer(x, lengths, states)
+            output, output_lengths, states = encoder.infer(x, lengths, states)
             assert output.shape == (U, B, D)
             assert torch.equal(output_lengths, torch.clamp(lengths - R, min=0))
             assert len(states) == num_encoder_layers
@@ -253,6 +258,7 @@ def test_emformer_encoder_infer():

 def test_emformer_forward():
     from emformer import Emformer
+
     num_features = 80
     output_dim = 1000
     chunk_length = 8
@@ -281,12 +287,13 @@ def test_emformer_forward():
         assert logits.shape == (B, U // 4, output_dim)
         assert torch.equal(
             output_lengths,
-            torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0)
+            torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0),
         )


 def test_emformer_infer():
     from emformer import Emformer
+
     num_features = 80
     output_dim = 1000
     chunk_length = 8
@@ -317,12 +324,11 @@ def test_emformer_infer():
             x = torch.randn(B, U + R + 3, num_features)
             x_lens = torch.randint(1, U + R + 3 + 1, (B,))
             x_lens[0] = U + R + 3
-            logits, output_lengths, states = \
-                model.infer(x, x_lens, states)
+            logits, output_lengths, states = model.infer(x, x_lens, states)
             assert logits.shape == (B, U // 4, output_dim)
             assert torch.equal(
                 output_lengths,
-                torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0)
+                torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0),
             )
             assert len(states) == num_encoder_layers
             for state in states:
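The expected-length expression in the two hunks above mirrors two stride-2 subsampling steps, each mapping n to (n - 1) // 2 (hence the overall factor of 4, matching U // 4 in the logits shape), followed by removal of the right context at the subsampled rate (R // 4). A quick numeric check with illustrative values:

import torch

R = 4                                # right-context frames at the input frame rate
x_lens = torch.tensor([100, 35, 3])  # arbitrary input lengths
expected = torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0)
assert expected.tolist() == [23, 7, 0]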
@@ -332,7 +338,7 @@ def test_emformer_infer():
                 assert state[2].shape == (L // 4, B, D)
                 assert torch.equal(
                     state[3],
-                    U // 4 * (chunk_idx + 1) * torch.ones_like(state[3])
+                    U // 4 * (chunk_idx + 1) * torch.ones_like(state[3]),
                 )
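The final assertion says that state[3] acts as a frame counter at the subsampled rate: after chunk_idx + 1 chunks, each carrying U input frames, every layer has processed U // 4 * (chunk_idx + 1) output frames. A toy check of just that arithmetic (the shape and dtype below are assumptions, not taken from the diff):

import torch

U, B = 8, 2
counter = torch.zeros(1, B, dtype=torch.int64)  # stands in for state[3]
for chunk_idx in range(3):
    counter = counter + U // 4                  # each chunk adds U // 4 subsampled frames
    assert torch.equal(
        counter, U // 4 * (chunk_idx + 1) * torch.ones_like(counter)
    )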