mirror of https://github.com/k2-fsa/icefall.git

commit 3e131891a2
parent d58002c414

update test_emformer.py
@@ -61,16 +61,20 @@ def test_emformer_attention_infer():
         left_context_key = torch.randn(L, B, D)
         left_context_val = torch.randn(L, B, D)
 
-        output_right_context_utterance, output_memory, next_key, next_val = \
-            attention.infer(
-                utterance,
-                lengths,
-                right_context,
-                summary,
-                memory,
-                left_context_key,
-                left_context_val,
-            )
+        (
+            output_right_context_utterance,
+            output_memory,
+            next_key,
+            next_val,
+        ) = attention.infer(
+            utterance,
+            lengths,
+            right_context,
+            summary,
+            memory,
+            left_context_key,
+            left_context_val,
+        )
         assert output_right_context_utterance.shape == (R + U, B, D)
         assert output_memory.shape == (S, B, D)
         assert next_key.shape == (L + U, B, D)
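
The hunk above (and the matching one in test_emformer_layer_infer below) swaps a backslash line continuation for black's parenthesized multi-target assignment. A minimal, self-contained sketch of the two styles, using a made-up function in place of attention.infer():

import operator  # stdlib only; nothing Emformer-specific here

# Hypothetical stand-in for attention.infer(), which returns four values.
def four_values():
    return 1, 2, 3, 4

# Old style: backslash continuation before the call.
a, b, c, d = \
    four_values()

# New style (what black emits): wrap the targets in parentheses so no
# backslash is needed and each name sits on its own line.
(
    a,
    b,
    c,
    d,
) = four_values()
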
@@ -98,7 +102,7 @@ def test_emformer_layer_forward():
             chunk_length=chunk_length,
             left_context_length=L,
             max_memory_size=M,
         )
 
         Q, KV = R + U + S, M + R + U
         utterance = torch.randn(U, B, D)
@@ -141,7 +145,7 @@ def test_emformer_layer_infer():
             chunk_length=chunk_length,
             left_context_length=L,
             max_memory_size=M,
         )
 
         utterance = torch.randn(U, B, D)
         lengths = torch.randint(1, U + 1, (B,))
@@ -149,14 +153,18 @@ def test_emformer_layer_infer():
         right_context = torch.randn(R, B, D)
         memory = torch.randn(M, B, D)
         state = None
-        output_utterance, output_right_context, output_memory, output_state = \
-            layer.infer(
-                utterance,
-                lengths,
-                right_context,
-                memory,
-                state,
-            )
+        (
+            output_utterance,
+            output_right_context,
+            output_memory,
+            output_state,
+        ) = layer.infer(
+            utterance,
+            lengths,
+            right_context,
+            memory,
+            state,
+        )
         assert output_utterance.shape == (U, B, D)
         assert output_right_context.shape == (R, B, D)
         if use_memory:
@@ -200,9 +208,7 @@ def test_emformer_encoder_forward():
 
         output, output_lengths = encoder(x, lengths)
         assert output.shape == (U, B, D)
-        assert torch.equal(
-            output_lengths, torch.clamp(lengths - R, min=0)
-        )
+        assert torch.equal(output_lengths, torch.clamp(lengths - R, min=0))
 
 
 def test_emformer_encoder_infer():
@@ -236,8 +242,7 @@ def test_emformer_encoder_infer():
             x = torch.randn(U + R, B, D)
             lengths = torch.randint(1, U + R + 1, (B,))
             lengths[0] = U + R
-            output, output_lengths, states = \
-                encoder.infer(x, lengths, states)
+            output, output_lengths, states = encoder.infer(x, lengths, states)
             assert output.shape == (U, B, D)
             assert torch.equal(output_lengths, torch.clamp(lengths - R, min=0))
             assert len(states) == num_encoder_layers
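
The collapsed one-liner above is the core streaming call: each chunk goes through encoder.infer() together with the states returned by the previous call. A toy sketch of that state-threading pattern; DummyEncoder is a made-up stand-in that only mimics the infer() signature exercised in this test, and starting with states = None is an assumption:

import torch

class DummyEncoder:
    """Made-up stand-in; the real Emformer encoder lives in emformer.py."""

    def infer(self, x, lengths, states):
        # A real implementation would consume the cached left-context and
        # memory states and return updated ones; this stub just passes
        # them through unchanged.
        new_states = states if states is not None else []
        return x, lengths, new_states

encoder = DummyEncoder()
states = None  # no cached context before the first chunk
for _ in range(3):  # feed three chunks in sequence
    x = torch.randn(8, 2, 16)  # (time, batch, feature)
    lengths = torch.full((2,), 8)
    output, output_lengths, states = encoder.infer(x, lengths, states)
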
@@ -253,6 +258,7 @@ def test_emformer_encoder_infer():
 
+
 def test_emformer_forward():
     from emformer import Emformer
 
     num_features = 80
     output_dim = 1000
     chunk_length = 8
@@ -281,12 +287,13 @@ def test_emformer_forward():
     assert logits.shape == (B, U // 4, output_dim)
     assert torch.equal(
         output_lengths,
-        torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0)
+        torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0),
     )
 
+
 def test_emformer_infer():
     from emformer import Emformer
 
     num_features = 80
     output_dim = 1000
     chunk_length = 8
@@ -317,12 +324,11 @@ def test_emformer_infer():
         x = torch.randn(B, U + R + 3, num_features)
         x_lens = torch.randint(1, U + R + 3 + 1, (B,))
         x_lens[0] = U + R + 3
-        logits, output_lengths, states = \
-            model.infer(x, x_lens, states)
+        logits, output_lengths, states = model.infer(x, x_lens, states)
         assert logits.shape == (B, U // 4, output_dim)
         assert torch.equal(
             output_lengths,
-            torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0)
+            torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0),
        )
         assert len(states) == num_encoder_layers
         for state in states:
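
The expected-length expression asserted in this hunk encodes the model's 4x frontend subsampling: two stride-2 stages each shrink a length n to (n - 1) // 2, and the R right-context frames shrink by the same factor of 4. The two-stage reading is an assumption inferred from the formula itself; a quick arithmetic check with invented lengths:

import torch

x_lens = torch.tensor([35, 20, 7])  # made-up example lengths
R = 4  # example right-context size; the test's actual value is not shown

after_stage1 = (x_lens - 1) // 2        # first stride-2 stage
after_stage2 = (after_stage1 - 1) // 2  # second stride-2 stage
out_lens = torch.clamp(after_stage2 - R // 4, min=0)

# Matches the expression asserted by the test.
assert torch.equal(
    out_lens,
    torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0),
)
print(out_lens)  # tensor([7, 3, 0])
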
@@ -332,7 +338,7 @@ def test_emformer_infer():
             assert state[2].shape == (L // 4, B, D)
             assert torch.equal(
                 state[3],
-                U // 4 * (chunk_idx + 1) * torch.ones_like(state[3])
+                U // 4 * (chunk_idx + 1) * torch.ones_like(state[3]),
             )
 
 
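
The final assert reads state[3] as a running counter of subsampled frames: each chunk contributes U // 4 frames after the 4x frontend, so after chunk_idx + 1 chunks the counter should be U // 4 * (chunk_idx + 1). A toy check with invented sizes (the counter semantics are an assumption based on the assert):

import torch

U = 8  # invented per-chunk size, in frames before subsampling
counter = torch.zeros(1, dtype=torch.int64)  # plays the role of state[3]
for chunk_idx in range(3):
    counter += U // 4  # each processed chunk adds U // 4 subsampled frames
    assert torch.equal(
        counter, U // 4 * (chunk_idx + 1) * torch.ones_like(counter)
    )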