Mirror of https://github.com/k2-fsa/icefall.git

commit df7919f4bf (parent 32420cc3e4)

    update test functions for conv_emformer_transducer/emformer.py
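This commit makes two changes. In emformer.py, the unused `dropout` option is removed from `EmformerAttention`: its call site in `EmformerLayer` always passed `dropout=0.0`, so the `nn.functional.dropout` applied after the attention softmax was a no-op. In the test file, three consistency tests are added, checking at the layer, encoder, and full-model level that chunk-by-chunk streaming `infer()` reproduces the output of a whole-utterance `forward()` pass; the new tests are also wired into the `__main__` block.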
@@ -14,8 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-# It is modified based on
-# https://github.com/pytorch/audio/blob/main/torchaudio/models/emformer.py.
+# It is modified based on https://github.com/pytorch/audio/blob/main/torchaudio/models/emformer.py.  # noqa
 
 import math
 import warnings
@@ -56,8 +55,6 @@ class EmformerAttention(nn.Module):
         Embedding dimension.
       nhead (int):
         Number of attention heads in each Emformer layer.
-      dropout (float, optional):
-        Dropout probability. (Default: 0.0)
       tanh_on_mem (bool, optional):
         If ``True``, applies tanh to memory elements. (Default: ``False``)
       negative_inf (float, optional):
@@ -68,7 +65,6 @@ class EmformerAttention(nn.Module):
         self,
         embed_dim: int,
         nhead: int,
-        dropout: float = 0.0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
     ):
@@ -82,7 +78,6 @@ class EmformerAttention(nn.Module):
 
         self.embed_dim = embed_dim
         self.nhead = nhead
-        self.dropout = dropout
         self.tanh_on_mem = tanh_on_mem
         self.negative_inf = negative_inf
 
@@ -154,9 +149,7 @@ class EmformerAttention(nn.Module):
         attention_probs = nn.functional.softmax(
             attention_weights_float, dim=-1
         ).type_as(attention_weights)
-        attention_probs = nn.functional.dropout(
-            attention_probs, p=float(self.dropout), training=self.training
-        )
         return attention_probs
 
     def _forward_impl(
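Net effect of the removal above, as a minimal runnable sketch (not the icefall code itself: the free function and the float32 upcast line are assumptions for illustration; the real `EmformerAttention` also handles projections, masking, and memory):

import torch
import torch.nn as nn

def attention_probs_after_commit(attention_weights: torch.Tensor) -> torch.Tensor:
    # Compute softmax in float32 for stability, then cast back to the input
    # dtype, mirroring the context lines of the hunk above.
    attention_weights_float = attention_weights.float()
    attention_probs = nn.functional.softmax(
        attention_weights_float, dim=-1
    ).type_as(attention_weights)
    # The nn.functional.dropout(...) call that used to follow is removed; its
    # only shown caller, EmformerLayer, passed dropout=0.0, so it was a no-op.
    return attention_probs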
@@ -481,7 +474,6 @@ class EmformerLayer(nn.Module):
         self.attention = EmformerAttention(
             embed_dim=d_model,
             nhead=nhead,
-            dropout=0.0,
             tanh_on_mem=tanh_on_mem,
             negative_inf=negative_inf,
         )
@@ -366,6 +366,216 @@ def test_emformer_infer():
     assert conv_cache.shape == (B, D, K - 1)
 
 
+def test_emformer_encoder_layer_forward_infer_consistency():
+    from emformer import EmformerEncoder
+
+    chunk_length = 4
+    num_chunks = 3
+    U = chunk_length * num_chunks
+    L, R = 1, 2
+    D = 256
+    num_encoder_layers = 1
+    memory_sizes = [0, 3]
+    K = 3
+
+    for M in memory_sizes:
+        encoder = EmformerEncoder(
+            chunk_length=chunk_length,
+            d_model=D,
+            dim_feedforward=1024,
+            num_encoder_layers=num_encoder_layers,
+            left_context_length=L,
+            right_context_length=R,
+            max_memory_size=M,
+            dropout=0.1,
+            cnn_module_kernel=K,
+            causal=True,
+        )
+        encoder.eval()
+        encoder_layer = encoder.emformer_layers[0]
+
+        x = torch.randn(U + R, 1, D)
+        lengths = torch.tensor([U])
+        right_context = encoder._gen_right_context(x)
+        utterance = x[: x.size(0) - R]
+        attention_mask = encoder._gen_attention_mask(utterance)
+        memory = (
+            encoder.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
+                :-1
+            ]
+            if encoder.use_memory
+            else torch.empty(0).to(dtype=x.dtype, device=x.device)
+        )
+        (
+            forward_output_utterance,
+            forward_output_right_context,
+            forward_output_memory,
+        ) = encoder_layer(
+            utterance,
+            lengths,
+            right_context,
+            memory,
+            attention_mask,
+        )
+
+        state = None
+        conv_cache = None
+        for chunk_idx in range(num_chunks):
+            start_idx = chunk_idx * chunk_length
+            end_idx = start_idx + chunk_length
+            chunk = x[start_idx:end_idx]
+            chunk_right_context = x[end_idx : end_idx + R]  # noqa
+            chunk_length = torch.tensor([chunk_length])
+            chunk_memory = (
+                encoder.init_memory_op(chunk.permute(1, 2, 0)).permute(2, 0, 1)
+                if encoder.use_memory
+                else torch.empty(0).to(dtype=x.dtype, device=x.device)
+            )
+            (
+                infer_output_chunk,
+                infer_right_context,
+                infer_output_memory,
+                state,
+                conv_cache,
+            ) = encoder_layer.infer(
+                chunk,
+                chunk_length,
+                chunk_right_context,
+                chunk_memory,
+                state,
+                conv_cache,
+            )
+            forward_output_chunk = forward_output_utterance[start_idx:end_idx]
+            assert torch.allclose(
+                infer_output_chunk,
+                forward_output_chunk,
+                atol=1e-5,
+                rtol=0.0,
+            )
+
+
+def test_emformer_encoder_forward_infer_consistency():
+    from emformer import EmformerEncoder
+
+    chunk_length = 4
+    num_chunks = 3
+    U = chunk_length * num_chunks
+    L, R = 1, 2
+    D = 256
+    num_encoder_layers = 3
+    K = 3
+    memory_sizes = [0, 3]
+
+    for M in memory_sizes:
+        encoder = EmformerEncoder(
+            chunk_length=chunk_length,
+            d_model=D,
+            dim_feedforward=1024,
+            num_encoder_layers=num_encoder_layers,
+            left_context_length=L,
+            right_context_length=R,
+            max_memory_size=M,
+            dropout=0.1,
+            cnn_module_kernel=K,
+            causal=True,
+        )
+        encoder.eval()
+
+        x = torch.randn(U + R, 1, D)
+        lengths = torch.tensor([U + R])
+
+        forward_output, forward_output_lengths = encoder(x, lengths)
+
+        states = None
+        conv_caches = None
+        for chunk_idx in range(num_chunks):
+            start_idx = chunk_idx * chunk_length
+            end_idx = start_idx + chunk_length
+            chunk = x[start_idx : end_idx + R]  # noqa
+            chunk_right_context = x[end_idx : end_idx + R]  # noqa
+            chunk_length = torch.tensor([chunk_length])
+            (
+                infer_output_chunk,
+                infer_output_lengths,
+                states,
+                conv_caches,
+            ) = encoder.infer(
+                chunk,
+                chunk_length,
+                states,
+                conv_caches,
+            )
+            forward_output_chunk = forward_output[start_idx:end_idx]
+            assert torch.allclose(
+                infer_output_chunk,
+                forward_output_chunk,
+                atol=1e-5,
+                rtol=0.0,
+            )
+
+
+def test_emformer_forward_infer_consistency():
+    from emformer import Emformer
+
+    num_features = 80
+    output_dim = 1000
+    chunk_length = 8
+    num_chunks = 3
+    U = chunk_length * num_chunks
+    L, R = 128, 4
+    D = 256
+    num_encoder_layers = 2
+    K = 3
+    memory_sizes = [0, 3]
+
+    for M in memory_sizes:
+        model = Emformer(
+            num_features=num_features,
+            output_dim=output_dim,
+            chunk_length=chunk_length,
+            subsampling_factor=4,
+            d_model=D,
+            num_encoder_layers=num_encoder_layers,
+            cnn_module_kernel=K,
+            left_context_length=L,
+            right_context_length=R,
+            max_memory_size=M,
+            dropout=0.1,
+            vgg_frontend=False,
+            causal=True,
+        )
+        model.eval()
+
+        x = torch.randn(1, U + R + 3, num_features)
+        x_lens = torch.tensor([x.size(1)])
+
+        # forward mode
+        forward_logits, _ = model(x, x_lens)
+
+        states = None
+        conv_caches = None
+        for chunk_idx in range(num_chunks):
+            start_idx = chunk_idx * chunk_length
+            end_idx = start_idx + chunk_length
+            chunk = x[:, start_idx : end_idx + R + 3]  # noqa
+            lengths = torch.tensor([chunk.size(1)])
+            (
+                infer_chunk_logits,
+                output_lengths,
+                states,
+                conv_caches,
+            ) = model.infer(chunk, lengths, states, conv_caches)
+            forward_chunk_logits = forward_logits[
+                :, start_idx // 4 : end_idx // 4  # noqa
+            ]
+            assert torch.allclose(
+                infer_chunk_logits,
+                forward_chunk_logits,
+                atol=1e-5,
+                rtol=0.0,
+            )
+
+
 if __name__ == "__main__":
     test_emformer_attention_forward()
     test_emformer_attention_infer()
@@ -375,3 +585,6 @@ if __name__ == "__main__":
     test_emformer_encoder_infer()
     test_emformer_forward()
     test_emformer_infer()
+    test_emformer_encoder_layer_forward_infer_consistency()
+    test_emformer_encoder_forward_infer_consistency()
+    test_emformer_forward_infer_consistency()
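All three new tests follow the same pattern: run `forward()` once over the whole utterance (plus right context), then replay the input chunk by chunk through `infer()`, threading the returned states and convolution caches between calls, and assert the streamed outputs match the full-pass outputs. A self-contained toy illustrating that pattern (not icefall code; `ToyCausalConv` and its cache layout are invented for this sketch):

import torch
import torch.nn as nn


class ToyCausalConv(nn.Module):
    """Toy stand-in for a streaming module with a convolution cache."""

    def __init__(self, d: int, k: int):
        super().__init__()
        self.k = k
        self.conv = nn.Conv1d(d, d, kernel_size=k)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, d, time); left-pad k-1 zeros so the conv is causal.
        pad = x.new_zeros(x.size(0), x.size(1), self.k - 1)
        return self.conv(torch.cat([pad, x], dim=2))

    def infer(self, chunk, cache=None):
        # cache holds the last k-1 input frames seen so far (the conv_cache).
        if cache is None:
            cache = chunk.new_zeros(chunk.size(0), chunk.size(1), self.k - 1)
        padded = torch.cat([cache, chunk], dim=2)
        return self.conv(padded), padded[:, :, -(self.k - 1) :]


torch.manual_seed(0)
model = ToyCausalConv(d=4, k=3).eval()
x = torch.randn(1, 4, 12)
full = model(x)

cache = None
outs = []
for start in range(0, 12, 4):  # stream in chunks of 4 frames
    out, cache = model.infer(x[:, :, start : start + 4], cache)
    outs.append(out)

assert torch.allclose(full, torch.cat(outs, dim=2), atol=1e-5, rtol=0.0)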