mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-19 05:54:20 +00:00
delete other test functions
This commit is contained in:
parent 448a94c00a
commit ce008aa2ca
@@ -1,516 +1,6 @@
import torch


def test_emformer_attention_forward():
    from emformer import EmformerAttention

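    # Tensor-shape conventions used throughout these tests (inferred from the
    # code): B = batch size, D = attention/embedding dim, U = utterance length
    # in frames, R = total right-context length, S = number of summary vectors
    # (one per chunk), M = memory bank length.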
    B, D = 2, 256
    chunk_length = 4
    right_context_length = 2
    num_chunks = 3
    U = num_chunks * chunk_length
    R = num_chunks * right_context_length
    attention = EmformerAttention(embed_dim=D, nhead=8)

    for use_memory in [True, False]:
        if use_memory:
            S = num_chunks
            M = S - 1
        else:
            S, M = 0, 0

        Q, KV = R + U + S, M + R + U
        utterance = torch.randn(U, B, D)
        lengths = torch.randint(1, U + 1, (B,))
        lengths[0] = U
        right_context = torch.randn(R, B, D)
        summary = torch.randn(S, B, D)
        memory = torch.randn(M, B, D)
        attention_mask = torch.rand(Q, KV) >= 0.5

        output_right_context_utterance, output_memory = attention(
            utterance,
            lengths,
            right_context,
            summary,
            memory,
            attention_mask,
        )
        assert output_right_context_utterance.shape == (R + U, B, D)
        assert output_memory.shape == (M, B, D)


def test_emformer_attention_infer():
    from emformer import EmformerAttention

    B, D = 2, 256
    U = 4
    R = 2
    L = 3
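    # L = length of the cached left-context keys/values carried over from
    # previously processed chunks during streaming inference.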
    attention = EmformerAttention(embed_dim=D, nhead=8)

    for use_memory in [True, False]:
        if use_memory:
            S, M = 1, 3
        else:
            S, M = 0, 0

        utterance = torch.randn(U, B, D)
        lengths = torch.randint(1, U + 1, (B,))
        lengths[0] = U
        right_context = torch.randn(R, B, D)
        summary = torch.randn(S, B, D)
        memory = torch.randn(M, B, D)
        left_context_key = torch.randn(L, B, D)
        left_context_val = torch.randn(L, B, D)

        (
            output_right_context_utterance,
            output_memory,
            next_key,
            next_val,
        ) = attention.infer(
            utterance,
            lengths,
            right_context,
            summary,
            memory,
            left_context_key,
            left_context_val,
        )
        assert output_right_context_utterance.shape == (R + U, B, D)
        assert output_memory.shape == (S, B, D)
        assert next_key.shape == (L + U, B, D)
        assert next_val.shape == (L + U, B, D)


def test_convolution_module_forward():
    from emformer import ConvolutionModule

    B, D = 2, 256
    chunk_length = 4
    right_context_length = 2
    num_chunks = 3
    U = num_chunks * chunk_length
    R = num_chunks * right_context_length
    kernel_size = 31
    conv_module = ConvolutionModule(
        chunk_length,
        right_context_length,
        D,
        kernel_size,
    )

    utterance = torch.randn(U, B, D)
    right_context = torch.randn(R, B, D)
    cache = torch.randn(B, D, kernel_size - 1)
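    # The cache presumably holds the trailing (kernel_size - 1) frames needed
    # by the causal convolution when the next chunk is processed.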

    utterance, right_context, new_cache = conv_module(
        utterance, right_context, cache
    )
    assert utterance.shape == (U, B, D)
    assert right_context.shape == (R, B, D)
    assert new_cache.shape == (B, D, kernel_size - 1)


def test_convolution_module_infer():
    from emformer import ConvolutionModule

    B, D = 2, 256
    chunk_length = 4
    right_context_length = 2
    num_chunks = 1
    U = num_chunks * chunk_length
    R = num_chunks * right_context_length
    kernel_size = 31
    conv_module = ConvolutionModule(
        chunk_length,
        right_context_length,
        D,
        kernel_size,
    )

    utterance = torch.randn(U, B, D)
    right_context = torch.randn(R, B, D)
    cache = torch.randn(B, D, kernel_size - 1)

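    # infer() is the streaming counterpart of the forward test above: it
    # processes a single chunk and returns an updated cache for the next call.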
    utterance, right_context, new_cache = conv_module.infer(
        utterance, right_context, cache
    )
    assert utterance.shape == (U, B, D)
    assert right_context.shape == (R, B, D)
    assert new_cache.shape == (B, D, kernel_size - 1)


def test_emformer_encoder_layer_forward():
    from emformer import EmformerEncoderLayer

    B, D = 2, 256
    chunk_length = 8
    right_context_length = 2
    left_context_length = 8
    kernel_size = 31
    num_chunks = 3
    U = num_chunks * chunk_length
    R = num_chunks * right_context_length

    for use_memory in [True, False]:
        if use_memory:
            S = num_chunks
            M = S - 1
        else:
            S, M = 0, 0

        layer = EmformerEncoderLayer(
            d_model=D,
            nhead=8,
            dim_feedforward=1024,
            chunk_length=chunk_length,
            cnn_module_kernel=kernel_size,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=M,
        )

        Q, KV = R + U + S, M + R + U
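        # Q is the number of query positions (right context + utterance +
        # summary) and KV the number of key/value positions (memory + right
        # context + utterance); a random boolean mask is enough for this
        # shape-only check.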
        utterance = torch.randn(U, B, D)
        lengths = torch.randint(1, U + 1, (B,))
        lengths[0] = U
        right_context = torch.randn(R, B, D)
        memory = torch.randn(M, B, D)
        attention_mask = torch.rand(Q, KV) >= 0.5

        output_utterance, output_right_context, output_memory = layer(
            utterance,
            lengths,
            right_context,
            memory,
            attention_mask,
        )
        assert output_utterance.shape == (U, B, D)
        assert output_right_context.shape == (R, B, D)
        assert output_memory.shape == (M, B, D)


def test_emformer_encoder_layer_infer():
    from emformer import EmformerEncoderLayer

    B, D = 2, 256
    chunk_length = 8
    right_context_length = 2
    left_context_length = 8
    kernel_size = 31
    num_chunks = 1
    U = num_chunks * chunk_length
    R = num_chunks * right_context_length

    for use_memory in [True, False]:
        if use_memory:
            M = 3
        else:
            M = 0

        layer = EmformerEncoderLayer(
            d_model=D,
            nhead=8,
            dim_feedforward=1024,
            chunk_length=chunk_length,
            cnn_module_kernel=kernel_size,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=M,
        )

        utterance = torch.randn(U, B, D)
        lengths = torch.randint(1, U + 1, (B,))
        lengths[0] = U
        right_context = torch.randn(R, B, D)
        memory = torch.randn(M, B, D)
        state = None
        conv_cache = None
        (
            output_utterance,
            output_right_context,
            output_memory,
            output_state,
            conv_cache,
        ) = layer.infer(
            utterance,
            lengths,
            right_context,
            memory,
            state,
            conv_cache,
        )
        assert output_utterance.shape == (U, B, D)
        assert output_right_context.shape == (R, B, D)
        if use_memory:
            assert output_memory.shape == (1, B, D)
        else:
            assert output_memory.shape == (0, B, D)
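        # output_state appears to hold, per layer: the memory bank, the cached
        # left-context keys, the cached left-context values, and a per-stream
        # count of frames processed so far.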
        assert len(output_state) == 4
        assert output_state[0].shape == (M, B, D)
        assert output_state[1].shape == (left_context_length, B, D)
        assert output_state[2].shape == (left_context_length, B, D)
        assert output_state[3].shape == (1, B)
        assert conv_cache.shape == (B, D, kernel_size - 1)


def test_emformer_encoder_forward():
    from emformer import EmformerEncoder

    B, D = 2, 256
    chunk_length = 4
    right_context_length = 2
    left_context_length = 2
    num_chunks = 3
    U = num_chunks * chunk_length
    kernel_size = 31
    num_encoder_layers = 2

    for use_memory in [True, False]:
        if use_memory:
            S = num_chunks
            M = S - 1
        else:
            S, M = 0, 0

        encoder = EmformerEncoder(
            chunk_length=chunk_length,
            d_model=D,
            dim_feedforward=1024,
            num_encoder_layers=num_encoder_layers,
            cnn_module_kernel=kernel_size,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=M,
        )

        x = torch.randn(U + right_context_length, B, D)
        lengths = torch.randint(1, U + right_context_length + 1, (B,))
        lengths[0] = U + right_context_length

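        # The encoder input carries right_context_length extra frames at the
        # end; the output keeps only the U utterance frames, and the returned
        # lengths are reduced accordingly.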
        output, output_lengths = encoder(x, lengths)
        assert output.shape == (U, B, D)
        assert torch.equal(
            output_lengths, torch.clamp(lengths - right_context_length, min=0)
        )


def test_emformer_encoder_infer():
    from emformer import EmformerEncoder

    B, D = 2, 256
    num_encoder_layers = 2
    chunk_length = 4
    right_context_length = 2
    left_context_length = 2
    num_chunks = 3
    kernel_size = 31

    for use_memory in [True, False]:
        if use_memory:
            M = 3
        else:
            M = 0

        encoder = EmformerEncoder(
            chunk_length=chunk_length,
            d_model=D,
            dim_feedforward=1024,
            num_encoder_layers=num_encoder_layers,
            cnn_module_kernel=kernel_size,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=M,
        )

        states = None
        conv_caches = None
        for chunk_idx in range(num_chunks):
            x = torch.randn(chunk_length + right_context_length, B, D)
            lengths = torch.randint(
                1, chunk_length + right_context_length + 1, (B,)
            )
            lengths[0] = chunk_length + right_context_length
            output, output_lengths, states, conv_caches = encoder.infer(
                x, lengths, states, conv_caches
            )
            assert output.shape == (chunk_length, B, D)
            assert torch.equal(
                output_lengths,
                torch.clamp(lengths - right_context_length, min=0),
            )
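            # Each layer returns a 4-element state; state[3] seems to track
            # the total number of utterance frames processed so far.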
            assert len(states) == num_encoder_layers
            for state in states:
                assert len(state) == 4
                assert state[0].shape == (M, B, D)
                assert state[1].shape == (left_context_length, B, D)
                assert state[2].shape == (left_context_length, B, D)
                assert torch.equal(
                    state[3],
                    (chunk_idx + 1) * chunk_length * torch.ones_like(state[3]),
                )
            for conv_cache in conv_caches:
                assert conv_cache.shape == (B, D, kernel_size - 1)


def test_emformer_encoder_forward_infer_consistency():
    from emformer import EmformerEncoder

    chunk_length = 4
    num_chunks = 3
    U = chunk_length * num_chunks
    left_context_length, right_context_length = 1, 2
    D = 256
    num_encoder_layers = 3
    kernel_size = 31
    memory_sizes = [0, 3]

    for M in memory_sizes:
        encoder = EmformerEncoder(
            chunk_length=chunk_length,
            d_model=D,
            dim_feedforward=1024,
            num_encoder_layers=num_encoder_layers,
            cnn_module_kernel=kernel_size,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=M,
        )
        encoder.eval()

        x = torch.randn(U + right_context_length, 1, D)
        lengths = torch.tensor([U + right_context_length])

        # training mode with full utterance
        forward_output, forward_output_lengths = encoder(x, lengths)

        # streaming inference mode with individual chunks
        states = None
        conv_caches = None
        for chunk_idx in range(num_chunks):
            start_idx = chunk_idx * chunk_length
            end_idx = start_idx + chunk_length
            chunk = x[start_idx : end_idx + right_context_length]  # noqa
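            # Each streamed chunk includes its own right_context_length
            # look-ahead frames, so chunkwise infer() should reproduce the
            # full-utterance forward output (compared below with allclose).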
            (
                infer_output_chunk,
                infer_output_lengths,
                states,
                conv_caches,
            ) = encoder.infer(
                chunk, torch.tensor([chunk_length]), states, conv_caches
            )
            forward_output_chunk = forward_output[start_idx:end_idx]
            assert torch.allclose(
                infer_output_chunk,
                forward_output_chunk,
                atol=1e-4,
                rtol=0.0,
            ), (
                infer_output_chunk - forward_output_chunk
            )


def test_emformer_forward():
    from emformer import Emformer

    num_features = 80
    chunk_length = 16
    right_context_length = 8
    left_context_length = 8
    num_chunks = 3
    U = num_chunks * chunk_length
    B, D = 2, 256
    kernel_size = 31

    for use_memory in [True, False]:
        if use_memory:
            M = 3
        else:
            M = 0
        model = Emformer(
            num_features=num_features,
            chunk_length=chunk_length,
            subsampling_factor=4,
            d_model=D,
            cnn_module_kernel=kernel_size,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=M,
        )
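        # Inputs are padded with right_context_length + 3 extra frames: the
        # right context plus the few frames consumed by the convolutional
        # subsampling front end (subsampling_factor=4), so each utterance
        # yields U // 4 output frames.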
        x = torch.randn(B, U + right_context_length + 3, num_features)
        x_lens = torch.randint(1, U + right_context_length + 3 + 1, (B,))
        x_lens[0] = U + right_context_length + 3
        output, output_lengths = model(x, x_lens)
        assert output.shape == (B, U // 4, D)
        assert torch.equal(
            output_lengths,
            torch.clamp(
                ((x_lens - 1) // 2 - 1) // 2 - right_context_length // 4, min=0
            ),
        )


def test_emformer_infer():
    from emformer import Emformer

    num_features = 80
    chunk_length = 8
    U = chunk_length
    left_context_length, right_context_length = 128, 4
    B, D = 2, 256
    num_chunks = 3
    num_encoder_layers = 2
    kernel_size = 31

    for use_memory in [True, False]:
        if use_memory:
            M = 3
        else:
            M = 0
        model = Emformer(
            num_features=num_features,
            chunk_length=chunk_length,
            subsampling_factor=4,
            d_model=D,
            num_encoder_layers=num_encoder_layers,
            cnn_module_kernel=kernel_size,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=M,
        )
        states = None
        conv_caches = None
        for chunk_idx in range(num_chunks):
            x = torch.randn(B, U + right_context_length + 3, num_features)
            x_lens = torch.randint(1, U + right_context_length + 3 + 1, (B,))
            x_lens[0] = U + right_context_length + 3
            output, output_lengths, states, conv_caches = model.infer(
                x, x_lens, states, conv_caches
            )
            assert output.shape == (B, U // 4, D)
            assert torch.equal(
                output_lengths,
                torch.clamp(
                    ((x_lens - 1) // 2 - 1) // 2 - right_context_length // 4,
                    min=0,
                ),
            )
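            # The per-layer caches live after the 4x subsampling, so the
            # left-context key/value caches and the processed-frame counter
            # are expressed in subsampled frames (hence the // 4).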
            assert len(states) == num_encoder_layers
            for state in states:
                assert len(state) == 4
                assert state[0].shape == (M, B, D)
                assert state[1].shape == (left_context_length // 4, B, D)
                assert state[2].shape == (left_context_length // 4, B, D)
                assert torch.equal(
                    state[3],
                    U // 4 * (chunk_idx + 1) * torch.ones_like(state[3]),
                )
            for conv_cache in conv_caches:
                assert conv_cache.shape == (B, D, kernel_size - 1)


def test_state_stack_unstack():
    from emformer import Emformer, stack_states, unstack_states

@@ -571,7 +61,7 @@ def test_state_stack_unstack():

def test_torchscript_consistency_infer():
    r"""Verify that scripting Emformer does not change the behavior of method `infer`."""
    r"""Verify that scripting Emformer does not change the behavior of method `infer`.""" # noqa
    from emformer import Emformer

    num_features = 80
@@ -628,16 +118,5 @@ def test_torchscript_consistency_infer():

if __name__ == "__main__":
    # test_emformer_attention_forward()
    # test_emformer_attention_infer()
    # test_convolution_module_forward()
    # test_convolution_module_infer()
    # test_emformer_encoder_layer_forward()
    # test_emformer_encoder_layer_infer()
    # test_emformer_encoder_forward()
    # test_emformer_encoder_infer()
    # test_emformer_encoder_forward_infer_consistency()
    # test_emformer_forward()
    # test_emformer_infer()
    # test_state_stack_unstack()
    test_state_stack_unstack()
    test_torchscript_consistency_infer()