delete other test functions

This commit is contained in:
yaozengwei 2022-06-11 22:09:53 +08:00
parent 448a94c00a
commit ce008aa2ca

View File

@ -1,516 +1,6 @@
import torch
def test_emformer_attention_forward():
    """Check output shapes of EmformerAttention in forward (training) mode,
    with and without memory slots."""
    from emformer import EmformerAttention

    batch_size, embed_dim = 2, 256
    chunk_length = 4
    right_context_length = 2
    num_chunks = 3
    U = num_chunks * chunk_length  # total utterance frames
    R = num_chunks * right_context_length  # total right-context frames
    attention = EmformerAttention(embed_dim=embed_dim, nhead=8)

    for use_memory in (True, False):
        # S summary frames; M memory slots (one fewer than summaries).
        S = num_chunks if use_memory else 0
        M = S - 1 if use_memory else 0
        Q, KV = R + U + S, M + R + U  # query / key-value lengths

        utterance = torch.randn(U, batch_size, embed_dim)
        lengths = torch.randint(1, U + 1, (batch_size,))
        lengths[0] = U  # guarantee one full-length sequence
        right_context = torch.randn(R, batch_size, embed_dim)
        summary = torch.randn(S, batch_size, embed_dim)
        memory = torch.randn(M, batch_size, embed_dim)
        attention_mask = torch.rand(Q, KV) >= 0.5

        output_right_context_utterance, output_memory = attention(
            utterance,
            lengths,
            right_context,
            summary,
            memory,
            attention_mask,
        )
        assert output_right_context_utterance.shape == (R + U, batch_size, embed_dim)
        assert output_memory.shape == (M, batch_size, embed_dim)
def test_emformer_attention_infer():
    """Check output shapes of EmformerAttention.infer (streaming mode),
    including the updated left-context key/value caches."""
    from emformer import EmformerAttention

    batch_size, embed_dim = 2, 256
    U = 4  # utterance frames in this chunk
    R = 2  # right-context frames
    L = 3  # cached left-context frames
    attention = EmformerAttention(embed_dim=embed_dim, nhead=8)

    for use_memory in (True, False):
        S, M = (1, 3) if use_memory else (0, 0)

        utterance = torch.randn(U, batch_size, embed_dim)
        lengths = torch.randint(1, U + 1, (batch_size,))
        lengths[0] = U  # guarantee one full-length sequence
        right_context = torch.randn(R, batch_size, embed_dim)
        summary = torch.randn(S, batch_size, embed_dim)
        memory = torch.randn(M, batch_size, embed_dim)
        left_context_key = torch.randn(L, batch_size, embed_dim)
        left_context_val = torch.randn(L, batch_size, embed_dim)

        (
            output_right_context_utterance,
            output_memory,
            next_key,
            next_val,
        ) = attention.infer(
            utterance,
            lengths,
            right_context,
            summary,
            memory,
            left_context_key,
            left_context_val,
        )
        assert output_right_context_utterance.shape == (R + U, batch_size, embed_dim)
        assert output_memory.shape == (S, batch_size, embed_dim)
        # Caches grow by the U frames just consumed.
        assert next_key.shape == (L + U, batch_size, embed_dim)
        assert next_val.shape == (L + U, batch_size, embed_dim)
def test_convolution_module_forward():
    """Check that ConvolutionModule.forward preserves the utterance and
    right-context shapes and emits a cache of kernel_size - 1 frames."""
    from emformer import ConvolutionModule

    batch_size, channels = 2, 256
    chunk_length = 4
    right_context_length = 2
    num_chunks = 3
    U = num_chunks * chunk_length
    R = num_chunks * right_context_length
    kernel_size = 31

    conv_module = ConvolutionModule(
        chunk_length,
        right_context_length,
        channels,
        kernel_size,
    )

    utterance = torch.randn(U, batch_size, channels)
    right_context = torch.randn(R, batch_size, channels)
    cache = torch.randn(batch_size, channels, kernel_size - 1)
    utterance, right_context, new_cache = conv_module(
        utterance, right_context, cache
    )

    assert utterance.shape == (U, batch_size, channels)
    assert right_context.shape == (R, batch_size, channels)
    assert new_cache.shape == (batch_size, channels, kernel_size - 1)
def test_convolution_module_infer():
    """Check that ConvolutionModule.infer (single-chunk streaming) preserves
    shapes and emits a cache of kernel_size - 1 frames."""
    from emformer import ConvolutionModule

    batch_size, channels = 2, 256
    chunk_length = 4
    right_context_length = 2
    num_chunks = 1  # streaming: exactly one chunk per call
    U = num_chunks * chunk_length
    R = num_chunks * right_context_length
    kernel_size = 31

    conv_module = ConvolutionModule(
        chunk_length,
        right_context_length,
        channels,
        kernel_size,
    )

    utterance = torch.randn(U, batch_size, channels)
    right_context = torch.randn(R, batch_size, channels)
    cache = torch.randn(batch_size, channels, kernel_size - 1)
    utterance, right_context, new_cache = conv_module.infer(
        utterance, right_context, cache
    )

    assert utterance.shape == (U, batch_size, channels)
    assert right_context.shape == (R, batch_size, channels)
    assert new_cache.shape == (batch_size, channels, kernel_size - 1)
def test_emformer_encoder_layer_forward():
    """Check output shapes of EmformerEncoderLayer.forward, with and
    without a memory bank."""
    from emformer import EmformerEncoderLayer

    batch_size, d_model = 2, 256
    chunk_length = 8
    right_context_length = 2
    left_context_length = 8
    kernel_size = 31
    num_chunks = 3
    U = num_chunks * chunk_length
    R = num_chunks * right_context_length

    for use_memory in (True, False):
        S = num_chunks if use_memory else 0
        M = S - 1 if use_memory else 0
        layer = EmformerEncoderLayer(
            d_model=d_model,
            nhead=8,
            dim_feedforward=1024,
            chunk_length=chunk_length,
            cnn_module_kernel=kernel_size,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=M,
        )

        Q, KV = R + U + S, M + R + U  # query / key-value lengths
        utterance = torch.randn(U, batch_size, d_model)
        lengths = torch.randint(1, U + 1, (batch_size,))
        lengths[0] = U  # guarantee one full-length sequence
        right_context = torch.randn(R, batch_size, d_model)
        memory = torch.randn(M, batch_size, d_model)
        attention_mask = torch.rand(Q, KV) >= 0.5

        output_utterance, output_right_context, output_memory = layer(
            utterance,
            lengths,
            right_context,
            memory,
            attention_mask,
        )
        assert output_utterance.shape == (U, batch_size, d_model)
        assert output_right_context.shape == (R, batch_size, d_model)
        assert output_memory.shape == (M, batch_size, d_model)
def test_emformer_encoder_layer_infer():
    """Check output shapes and the 4-element state layout returned by
    EmformerEncoderLayer.infer."""
    from emformer import EmformerEncoderLayer

    batch_size, d_model = 2, 256
    chunk_length = 8
    right_context_length = 2
    left_context_length = 8
    kernel_size = 31
    num_chunks = 1  # streaming: one chunk per call
    U = num_chunks * chunk_length
    R = num_chunks * right_context_length

    for use_memory in (True, False):
        M = 3 if use_memory else 0
        layer = EmformerEncoderLayer(
            d_model=d_model,
            nhead=8,
            dim_feedforward=1024,
            chunk_length=chunk_length,
            cnn_module_kernel=kernel_size,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=M,
        )

        utterance = torch.randn(U, batch_size, d_model)
        lengths = torch.randint(1, U + 1, (batch_size,))
        lengths[0] = U  # guarantee one full-length sequence
        right_context = torch.randn(R, batch_size, d_model)
        memory = torch.randn(M, batch_size, d_model)
        state = None
        conv_cache = None

        (
            output_utterance,
            output_right_context,
            output_memory,
            output_state,
            conv_cache,
        ) = layer.infer(
            utterance,
            lengths,
            right_context,
            memory,
            state,
            conv_cache,
        )

        assert output_utterance.shape == (U, batch_size, d_model)
        assert output_right_context.shape == (R, batch_size, d_model)
        # One memory vector per processed chunk when memory is enabled.
        expected_memory_frames = 1 if use_memory else 0
        assert output_memory.shape == (expected_memory_frames, batch_size, d_model)
        # State layout: [memory bank, left-context key, left-context value,
        # processed-frame counter].
        assert len(output_state) == 4
        assert output_state[0].shape == (M, batch_size, d_model)
        assert output_state[1].shape == (left_context_length, batch_size, d_model)
        assert output_state[2].shape == (left_context_length, batch_size, d_model)
        assert output_state[3].shape == (1, batch_size)
        assert conv_cache.shape == (batch_size, d_model, kernel_size - 1)
def test_emformer_encoder_forward():
    """Check output shape and length handling of EmformerEncoder.forward:
    the trailing right-context frames are consumed, not emitted."""
    from emformer import EmformerEncoder

    batch_size, d_model = 2, 256
    chunk_length = 4
    right_context_length = 2
    left_context_length = 2
    num_chunks = 3
    U = num_chunks * chunk_length
    kernel_size = 31
    num_encoder_layers = 2

    for use_memory in (True, False):
        M = num_chunks - 1 if use_memory else 0
        encoder = EmformerEncoder(
            chunk_length=chunk_length,
            d_model=d_model,
            dim_feedforward=1024,
            num_encoder_layers=num_encoder_layers,
            cnn_module_kernel=kernel_size,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=M,
        )

        x = torch.randn(U + right_context_length, batch_size, d_model)
        lengths = torch.randint(1, U + right_context_length + 1, (batch_size,))
        lengths[0] = U + right_context_length  # one full-length sequence

        output, output_lengths = encoder(x, lengths)
        assert output.shape == (U, batch_size, d_model)
        assert torch.equal(
            output_lengths, torch.clamp(lengths - right_context_length, min=0)
        )
def test_emformer_encoder_infer():
    """Stream several chunks through EmformerEncoder.infer, verifying output
    shapes, per-layer state layout, and the processed-frame counter."""
    from emformer import EmformerEncoder

    batch_size, d_model = 2, 256
    num_encoder_layers = 2
    chunk_length = 4
    right_context_length = 2
    left_context_length = 2
    num_chunks = 3
    kernel_size = 31

    for use_memory in (True, False):
        M = 3 if use_memory else 0
        encoder = EmformerEncoder(
            chunk_length=chunk_length,
            d_model=d_model,
            dim_feedforward=1024,
            num_encoder_layers=num_encoder_layers,
            cnn_module_kernel=kernel_size,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=M,
        )

        states = None
        conv_caches = None
        for chunk_idx in range(num_chunks):
            x = torch.randn(
                chunk_length + right_context_length, batch_size, d_model
            )
            lengths = torch.randint(
                1, chunk_length + right_context_length + 1, (batch_size,)
            )
            lengths[0] = chunk_length + right_context_length

            output, output_lengths, states, conv_caches = encoder.infer(
                x, lengths, states, conv_caches
            )

            assert output.shape == (chunk_length, batch_size, d_model)
            assert torch.equal(
                output_lengths,
                torch.clamp(lengths - right_context_length, min=0),
            )
            assert len(states) == num_encoder_layers
            for state in states:
                assert len(state) == 4
                assert state[0].shape == (M, batch_size, d_model)
                assert state[1].shape == (left_context_length, batch_size, d_model)
                assert state[2].shape == (left_context_length, batch_size, d_model)
                # state[3] counts frames processed so far across calls.
                assert torch.equal(
                    state[3],
                    (chunk_idx + 1) * chunk_length * torch.ones_like(state[3]),
                )
            for conv_cache in conv_caches:
                assert conv_cache.shape == (batch_size, d_model, kernel_size - 1)
def test_emformer_encoder_forward_infer_consistency():
    """Verify that chunk-wise streaming inference reproduces the
    whole-utterance forward output (in eval mode), for both memory sizes."""
    from emformer import EmformerEncoder

    chunk_length = 4
    num_chunks = 3
    U = chunk_length * num_chunks
    left_context_length, right_context_length = 1, 2
    d_model = 256
    num_encoder_layers = 3
    kernel_size = 31

    for max_memory_size in (0, 3):
        encoder = EmformerEncoder(
            chunk_length=chunk_length,
            d_model=d_model,
            dim_feedforward=1024,
            num_encoder_layers=num_encoder_layers,
            cnn_module_kernel=kernel_size,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=max_memory_size,
        )
        encoder.eval()  # disable dropout so the two paths can agree

        x = torch.randn(U + right_context_length, 1, d_model)
        lengths = torch.tensor([U + right_context_length])

        # Training mode: one pass over the full utterance.
        forward_output, forward_output_lengths = encoder(x, lengths)

        # Streaming mode: one chunk (plus its right context) at a time.
        states = None
        conv_caches = None
        for chunk_idx in range(num_chunks):
            start_idx = chunk_idx * chunk_length
            end_idx = start_idx + chunk_length
            chunk = x[start_idx : end_idx + right_context_length]  # noqa
            (
                infer_output_chunk,
                infer_output_lengths,
                states,
                conv_caches,
            ) = encoder.infer(
                chunk, torch.tensor([chunk_length]), states, conv_caches
            )
            forward_output_chunk = forward_output[start_idx:end_idx]
            assert torch.allclose(
                infer_output_chunk,
                forward_output_chunk,
                atol=1e-4,
                rtol=0.0,
            ), (
                infer_output_chunk - forward_output_chunk
            )
def test_emformer_forward():
    """Check output shape and length arithmetic of Emformer.forward; the
    length formula below implies a two-stage stride-2 subsampling frontend."""
    from emformer import Emformer

    num_features = 80
    chunk_length = 16
    right_context_length = 8
    left_context_length = 8
    num_chunks = 3
    U = num_chunks * chunk_length
    batch_size, d_model = 2, 256
    kernel_size = 31

    for use_memory in (True, False):
        M = 3 if use_memory else 0
        model = Emformer(
            num_features=num_features,
            chunk_length=chunk_length,
            subsampling_factor=4,
            d_model=d_model,
            cnn_module_kernel=kernel_size,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=M,
        )

        # NOTE(review): the +3 extra frames presumably feed the subsampling
        # frontend's receptive field — confirm against the Emformer frontend.
        x = torch.randn(batch_size, U + right_context_length + 3, num_features)
        x_lens = torch.randint(
            1, U + right_context_length + 3 + 1, (batch_size,)
        )
        x_lens[0] = U + right_context_length + 3

        output, output_lengths = model(x, x_lens)
        assert output.shape == (batch_size, U // 4, d_model)
        assert torch.equal(
            output_lengths,
            torch.clamp(
                ((x_lens - 1) // 2 - 1) // 2 - right_context_length // 4, min=0
            ),
        )
def test_emformer_infer():
    """Stream chunks through Emformer.infer, checking output shapes,
    per-layer state layout (scaled by the 4x subsampling), and the
    processed-frame counter."""
    from emformer import Emformer

    num_features = 80
    chunk_length = 8
    U = chunk_length
    left_context_length, right_context_length = 128, 4
    batch_size, d_model = 2, 256
    num_chunks = 3
    num_encoder_layers = 2
    kernel_size = 31

    for use_memory in (True, False):
        M = 3 if use_memory else 0
        model = Emformer(
            num_features=num_features,
            chunk_length=chunk_length,
            subsampling_factor=4,
            d_model=d_model,
            num_encoder_layers=num_encoder_layers,
            cnn_module_kernel=kernel_size,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=M,
        )

        states = None
        conv_caches = None
        for chunk_idx in range(num_chunks):
            x = torch.randn(
                batch_size, U + right_context_length + 3, num_features
            )
            x_lens = torch.randint(
                1, U + right_context_length + 3 + 1, (batch_size,)
            )
            x_lens[0] = U + right_context_length + 3

            output, output_lengths, states, conv_caches = model.infer(
                x, x_lens, states, conv_caches
            )

            assert output.shape == (batch_size, U // 4, d_model)
            assert torch.equal(
                output_lengths,
                torch.clamp(
                    ((x_lens - 1) // 2 - 1) // 2 - right_context_length // 4,
                    min=0,
                ),
            )
            assert len(states) == num_encoder_layers
            for state in states:
                assert len(state) == 4
                assert state[0].shape == (M, batch_size, d_model)
                # Left-context caches live after subsampling, hence // 4.
                assert state[1].shape == (
                    left_context_length // 4, batch_size, d_model
                )
                assert state[2].shape == (
                    left_context_length // 4, batch_size, d_model
                )
                assert torch.equal(
                    state[3],
                    U // 4 * (chunk_idx + 1) * torch.ones_like(state[3]),
                )
            for conv_cache in conv_caches:
                assert conv_cache.shape == (batch_size, d_model, kernel_size - 1)
def test_state_stack_unstack():
from emformer import Emformer, stack_states, unstack_states
@ -571,7 +61,7 @@ def test_state_stack_unstack():
def test_torchscript_consistency_infer():
r"""Verify that scripting Emformer does not change the behavior of method `infer`."""
r"""Verify that scripting Emformer does not change the behavior of method `infer`.""" # noqa
from emformer import Emformer
num_features = 80
@ -628,16 +118,5 @@ def test_torchscript_consistency_infer():
if __name__ == "__main__":
    # Only the tests kept in this file are run; the commented-out calls to
    # the deleted test functions were dead code and have been removed.
    test_state_stack_unstack()
    test_torchscript_consistency_infer()