Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-07 16:14:17 +00:00
update test functions for emformer.
commit 524f3aa015 (parent 4130892971)
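
The test-side hunks further down in this diff all share one pattern: chunk-by-chunk streaming inference is compared against the full-utterance forward pass with torch.allclose instead of printing the raw difference. A minimal sketch of that pattern, with hypothetical names standing in for the real test tensors:

import torch


def check_chunk_consistency(forward_output, infer_chunks, chunk_length, atol=1e-5):
    # forward_output: full-utterance output of shape (T, B, D);
    # infer_chunks: streaming outputs, one tensor of chunk_length frames each.
    for chunk_idx, infer_chunk in enumerate(infer_chunks):
        start_idx = chunk_idx * chunk_length
        end_idx = start_idx + chunk_length
        # rtol=0.0 makes atol the only tolerance, matching the tests below.
        assert torch.allclose(
            infer_chunk, forward_output[start_idx:end_idx], atol=atol, rtol=0.0
        )


# Dummy data standing in for real model outputs:
full = torch.randn(8, 1, 4)
check_chunk_consistency(full, [full[0:4], full[4:8]], chunk_length=4)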
@@ -85,8 +85,6 @@ class EmformerAttention(nn.Module):
         Embedding dimension.
       nhead (int):
         Number of attention heads in each Emformer layer.
-      dropout (float, optional):
-        Dropout probability. (Default: 0.0)
       weight_init_gain (float or None, optional):
         Scale factor to apply when initializing attention
         module parameters. (Default: ``None``)
@@ -100,7 +98,6 @@ class EmformerAttention(nn.Module):
         self,
         embed_dim: int,
         nhead: int,
-        dropout: float = 0.0,
         weight_init_gain: Optional[float] = None,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
@@ -115,7 +112,6 @@ class EmformerAttention(nn.Module):
 
         self.embed_dim = embed_dim
         self.nhead = nhead
-        self.dropout = dropout
         self.tanh_on_mem = tanh_on_mem
         self.negative_inf = negative_inf
 
@@ -183,9 +179,7 @@ class EmformerAttention(nn.Module):
         attention_probs = nn.functional.softmax(
             attention_weights_float, dim=-1
         ).type_as(attention_weights)
-        # attention_probs = nn.functional.dropout(
-        #     attention_probs, p=float(self.dropout), training=self.training
-        # )
         return attention_probs
 
     def _forward_impl(
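
With the unused dropout path deleted, the attention probabilities now come straight out of the softmax. The float-then-cast idiom kept in this hunk is a common numerical-stability guard: compute the softmax in float32, then cast back to the original dtype of the weights. A small illustration of what .type_as does here (tensor values are arbitrary):

import torch
import torch.nn as nn

# Half-precision weights, as they might be under mixed-precision training.
attention_weights = torch.randn(2, 3, dtype=torch.float16)
attention_weights_float = attention_weights.float()
# Softmax in float32 for stability, then cast back via type_as.
attention_probs = nn.functional.softmax(
    attention_weights_float, dim=-1
).type_as(attention_weights)
assert attention_probs.dtype == torch.float16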
@@ -512,7 +506,6 @@ class EmformerLayer(nn.Module):
         self.attention = EmformerAttention(
             embed_dim=d_model,
             nhead=nhead,
-            dropout=dropout,
             weight_init_gain=weight_init_gain,
             tanh_on_mem=tanh_on_mem,
             negative_inf=negative_inf,
@@ -362,8 +362,9 @@ def test_emformer_attention_forward_infer_consistency():
         left_context_length=L,
         right_context_length=R,
         max_memory_size=M,
-        dropout=0.0,
+        dropout=0.1,
     )
+    encoder.eval()
     encoder_layer = encoder.emformer_layers[0]
 
     x = torch.randn(U + R, 1, D)
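
The dropout=0.1 plus encoder.eval() combination recurs in the remaining test hunks. It works because nn.Dropout is the identity in eval mode, so the consistency assertions stay deterministic while the tests now also exercise construction with a nonzero dropout value. A quick demonstration of that PyTorch behavior:

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.1)
x = torch.randn(5)
drop.eval()
# In eval mode dropout is a no-op; in train mode it would randomly
# zero elements and rescale the rest by 1 / (1 - p).
assert torch.equal(drop(x), x)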
@@ -415,12 +416,15 @@ def test_emformer_attention_forward_infer_consistency():
             chunk_memory,
             state,
         )
-        infer_output_utterance = infer_output_right_context_utterance[
+        infer_output_chunk = infer_output_right_context_utterance[
             chunk_right_context.size(0) :  # noqa
         ]
-        print(
-            infer_output_utterance
-            - forward_output_utterance[start_idx:end_idx]
-        )
+        forward_output_chunk = forward_output_utterance[start_idx:end_idx]
+        assert torch.allclose(
+            infer_output_chunk,
+            forward_output_chunk,
+            atol=1e-6,
+            rtol=0.0,
+        )
 
 
@@ -444,8 +448,9 @@ def test_emformer_layer_forward_infer_consistency():
         left_context_length=L,
         right_context_length=R,
         max_memory_size=M,
-        dropout=0.0,
+        dropout=0.1,
     )
+    encoder.eval()
     encoder_layer = encoder.emformer_layers[0]
 
     x = torch.randn(U + R, 1, D)
@@ -485,7 +490,7 @@ def test_emformer_layer_forward_infer_consistency():
             else torch.empty(0).to(dtype=x.dtype, device=x.device)
         )
         (
-            infer_output_utterance,
+            infer_output_chunk,
             infer_right_context,
             infer_output_memory,
             state,
@@ -496,9 +501,12 @@ def test_emformer_layer_forward_infer_consistency():
             chunk_memory,
             state,
         )
-        print(
-            infer_output_utterance
-            - forward_output_utterance[start_idx:end_idx]
-        )
+        forward_output_chunk = forward_output_utterance[start_idx:end_idx]
+        assert torch.allclose(
+            infer_output_chunk,
+            forward_output_chunk,
+            atol=1e-5,
+            rtol=0.0,
+        )
 
 
@@ -522,8 +530,9 @@ def test_emformer_encoder_forward_infer_consistency():
         left_context_length=L,
         right_context_length=R,
         max_memory_size=M,
-        dropout=0.0,
+        dropout=0.1,
     )
+    encoder.eval()
 
     x = torch.randn(U + R, 1, D)
     lengths = torch.tensor([U + R])
@@ -537,23 +546,152 @@ def test_emformer_encoder_forward_infer_consistency():
         chunk = x[start_idx : end_idx + R]  # noqa
         chunk_right_context = x[end_idx : end_idx + R]  # noqa
         chunk_length = torch.tensor([chunk_length])
-        infer_output, infer_output_lengths, states = encoder.infer(
+        infer_output_chunk, infer_output_lengths, states = encoder.infer(
             chunk,
             chunk_length,
             states,
         )
-        print(infer_output - forward_output[start_idx:end_idx])
+        forward_output_chunk = forward_output[start_idx:end_idx]
+        assert torch.allclose(
+            infer_output_chunk,
+            forward_output_chunk,
+            atol=1e-5,
+            rtol=0.0,
+        )
 
 
+def test_emformer_infer_batch_single_consistency():
+    """Test consistency of cached states and output logits between single
+    utterance inference and batch inference."""
+    from emformer import Emformer
+
+    num_features = 80
+    output_dim = 1000
+    chunk_length = 8
+    num_chunks = 3
+    U = num_chunks * chunk_length
+    L, R = 128, 4
+    B, D = 2, 256
+    num_encoder_layers = 2
+    for use_memory in [True, False]:
+        if use_memory:
+            M = 3
+        else:
+            M = 0
+        model = Emformer(
+            num_features=num_features,
+            output_dim=output_dim,
+            chunk_length=chunk_length,
+            subsampling_factor=4,
+            d_model=D,
+            num_encoder_layers=num_encoder_layers,
+            left_context_length=L,
+            right_context_length=R,
+            max_memory_size=M,
+            vgg_frontend=False,
+        )
+        model.eval()
+
+        def save_states(states):
+            saved_states = []
+            for layer_idx in range(len(states)):
+                layer_state = []
+                layer_state.append(states[layer_idx][0].clone())  # memory
+                layer_state.append(
+                    states[layer_idx][1].clone()
+                )  # left_context_key
+                layer_state.append(
+                    states[layer_idx][2].clone()
+                )  # left_context_val
+                layer_state.append(states[layer_idx][3].clone())  # past_length
+                saved_states.append(layer_state)
+            return saved_states
+
+        def assert_states_equal(saved_states, states, sample_idx):
+            for layer_idx in range(len(saved_states)):
+                # assert equal memory
+                assert torch.allclose(
+                    states[layer_idx][0],
+                    saved_states[layer_idx][0][
+                        :, sample_idx : sample_idx + 1  # noqa
+                    ],
+                    atol=1e-5,
+                    rtol=0.0,
+                )
+                # assert equal left_context_key
+                assert torch.allclose(
+                    states[layer_idx][1],
+                    saved_states[layer_idx][1][
+                        :, sample_idx : sample_idx + 1  # noqa
+                    ],
+                    atol=1e-5,
+                    rtol=0.0,
+                )
+                # assert equal left_context_val
+                assert torch.allclose(
+                    states[layer_idx][2],
+                    saved_states[layer_idx][2][
+                        :, sample_idx : sample_idx + 1  # noqa
+                    ],
+                    atol=1e-5,
+                    rtol=0.0,
+                )
+                # assert equal past_length
+                assert torch.equal(
+                    states[layer_idx][3],
+                    saved_states[layer_idx][3][
+                        :, sample_idx : sample_idx + 1  # noqa
+                    ],
+                )
+
+        x = torch.randn(B, U + R + 3, num_features)
+        batch_logits = []
+        batch_states = []
+        states = None
+        for chunk_idx in range(num_chunks):
+            start_idx = chunk_idx * chunk_length
+            end_idx = start_idx + chunk_length
+            chunk = x[:, start_idx : end_idx + R + 3]  # noqa
+            lengths = torch.tensor([chunk_length + R + 3]).expand(B)
+            logits, output_lengths, states = model.infer(chunk, lengths, states)
+            batch_logits.append(logits)
+            batch_states.append(save_states(states))
+        batch_logits = torch.cat(batch_logits, dim=1)
+
+        single_logits = []
+        for sample_idx in range(B):
+            sample = x[sample_idx : sample_idx + 1]  # noqa
+            chunk_logits = []
+            states = None
+            for chunk_idx in range(num_chunks):
+                start_idx = chunk_idx * chunk_length
+                end_idx = start_idx + chunk_length
+                chunk = sample[:, start_idx : end_idx + R + 3]  # noqa
+                lengths = torch.tensor([chunk_length + R + 3])
+                logits, output_lengths, states = model.infer(
+                    chunk, lengths, states
+                )
+                chunk_logits.append(logits)
+
+                assert_states_equal(batch_states[chunk_idx], states, sample_idx)
+
+            chunk_logits = torch.cat(chunk_logits, dim=1)
+            single_logits.append(chunk_logits)
+        single_logits = torch.cat(single_logits, dim=0)
+
+        assert torch.allclose(batch_logits, single_logits, atol=1e-5, rtol=0.0)
+
+
 if __name__ == "__main__":
-    # test_emformer_attention_forward()
-    # test_emformer_attention_infer()
-    # test_emformer_layer_forward()
-    # test_emformer_layer_infer()
-    # test_emformer_encoder_forward()
-    # test_emformer_encoder_infer()
-    # test_emformer_forward()
-    # test_emformer_infer()
-    # test_emformer_attention_forward_infer_consistency()
-    # test_emformer_layer_forward_infer_consistency()
+    test_emformer_attention_forward()
+    test_emformer_attention_infer()
+    test_emformer_layer_forward()
+    test_emformer_layer_infer()
+    test_emformer_encoder_forward()
+    test_emformer_encoder_infer()
+    test_emformer_forward()
+    test_emformer_infer()
+    test_emformer_attention_forward_infer_consistency()
+    test_emformer_layer_forward_infer_consistency()
     test_emformer_encoder_forward_infer_consistency()
+    test_emformer_infer_batch_single_consistency()
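
The new batch/single test depends on the cached-state layout visible in save_states and assert_states_equal: each layer's state is [memory, left_context_key, left_context_val, past_length], with the batch on dim 1, so one sample's slice is state[:, sample_idx : sample_idx + 1]. A sketch of that slicing with made-up shapes (not taken from the model):

import torch

# Hypothetical single-layer state for a batch of size 2, in the order
# checked by the new test: memory, left_context_key, left_context_val,
# past_length. The batch dimension is dim 1 throughout.
layer_state = [
    torch.randn(3, 2, 256),
    torch.randn(128, 2, 256),
    torch.randn(128, 2, 256),
    torch.zeros(1, 2, dtype=torch.int64),
]

sample_idx = 0
single = [t[:, sample_idx : sample_idx + 1] for t in layer_state]
assert all(t.size(1) == 1 for t in single)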