From df7919f4bf0f3c84977363551a65a61b093ca2dd Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Thu, 14 Apr 2022 19:16:30 +0800
Subject: [PATCH] update test functions for conv_emformer_transducer/emformer.py

---
 .../ASR/conv_emformer_transducer/emformer.py  |  12 +-
 .../conv_emformer_transducer/test_emformer.py | 213 ++++++++++++++++++
 2 files changed, 215 insertions(+), 10 deletions(-)

diff --git a/egs/librispeech/ASR/conv_emformer_transducer/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer/emformer.py
index e9ce56aa7..14e106460 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer/emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer/emformer.py
@@ -14,8 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-# It is modified based on
-# https://github.com/pytorch/audio/blob/main/torchaudio/models/emformer.py.
+# It is modified based on https://github.com/pytorch/audio/blob/main/torchaudio/models/emformer.py.  # noqa
 
 import math
 import warnings
@@ -56,8 +55,6 @@ class EmformerAttention(nn.Module):
         Embedding dimension.
       nhead (int):
         Number of attention heads in each Emformer layer.
-      dropout (float, optional):
-        Dropout probability. (Default: 0.0)
       tanh_on_mem (bool, optional):
         If ``True``, applies tanh to memory elements. (Default: ``False``)
       negative_inf (float, optional):
@@ -68,7 +65,6 @@
         self,
         embed_dim: int,
         nhead: int,
-        dropout: float = 0.0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
     ):
@@ -82,7 +78,6 @@ class EmformerAttention(nn.Module):
 
         self.embed_dim = embed_dim
         self.nhead = nhead
-        self.dropout = dropout
         self.tanh_on_mem = tanh_on_mem
         self.negative_inf = negative_inf
 
@@ -154,9 +149,7 @@ class EmformerAttention(nn.Module):
         attention_probs = nn.functional.softmax(
             attention_weights_float, dim=-1
         ).type_as(attention_weights)
-        attention_probs = nn.functional.dropout(
-            attention_probs, p=float(self.dropout), training=self.training
-        )
+
         return attention_probs
 
     def _forward_impl(
@@ -481,7 +474,6 @@ class EmformerLayer(nn.Module):
         self.attention = EmformerAttention(
             embed_dim=d_model,
             nhead=nhead,
-            dropout=0.0,
             tanh_on_mem=tanh_on_mem,
             negative_inf=negative_inf,
         )
diff --git a/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py b/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py
index 41e911e17..971abca97 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py
@@ -366,6 +366,216 @@ def test_emformer_infer():
         assert conv_cache.shape == (B, D, K - 1)
 
 
+def test_emformer_encoder_layer_forward_infer_consistency():
+    from emformer import EmformerEncoder
+
+    chunk_length = 4
+    num_chunks = 3
+    U = chunk_length * num_chunks
+    L, R = 1, 2
+    D = 256
+    num_encoder_layers = 1
+    memory_sizes = [0, 3]
+    K = 3
+
+    for M in memory_sizes:
+        encoder = EmformerEncoder(
+            chunk_length=chunk_length,
+            d_model=D,
+            dim_feedforward=1024,
+            num_encoder_layers=num_encoder_layers,
+            left_context_length=L,
+            right_context_length=R,
+            max_memory_size=M,
+            dropout=0.1,
+            cnn_module_kernel=K,
+            causal=True,
+        )
+        encoder.eval()
+        encoder_layer = encoder.emformer_layers[0]
+
+        x = torch.randn(U + R, 1, D)
+        lengths = torch.tensor([U])
+        right_context = encoder._gen_right_context(x)
+        utterance = x[: x.size(0) - R]
+        attention_mask = encoder._gen_attention_mask(utterance)
+        memory = (
+            encoder.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
+                :-1
+            ]
+            if encoder.use_memory
+            else torch.empty(0).to(dtype=x.dtype, device=x.device)
+        )
+        (
+            forward_output_utterance,
+            forward_output_right_context,
+            forward_output_memory,
+        ) = encoder_layer(
+            utterance,
+            lengths,
+            right_context,
+            memory,
+            attention_mask,
+        )
+
+        state = None
+        conv_cache = None
+        for chunk_idx in range(num_chunks):
+            start_idx = chunk_idx * chunk_length
+            end_idx = start_idx + chunk_length
+            chunk = x[start_idx:end_idx]
+            chunk_right_context = x[end_idx : end_idx + R]  # noqa
+            chunk_length = torch.tensor([chunk_length])
+            chunk_memory = (
+                encoder.init_memory_op(chunk.permute(1, 2, 0)).permute(2, 0, 1)
+                if encoder.use_memory
+                else torch.empty(0).to(dtype=x.dtype, device=x.device)
+            )
+            (
+                infer_output_chunk,
+                infer_right_context,
+                infer_output_memory,
+                state,
+                conv_cache,
+            ) = encoder_layer.infer(
+                chunk,
+                chunk_length,
+                chunk_right_context,
+                chunk_memory,
+                state,
+                conv_cache,
+            )
+            forward_output_chunk = forward_output_utterance[start_idx:end_idx]
+            assert torch.allclose(
+                infer_output_chunk,
+                forward_output_chunk,
+                atol=1e-5,
+                rtol=0.0,
+            )
+
+
+def test_emformer_encoder_forward_infer_consistency():
+    from emformer import EmformerEncoder
+
+    chunk_length = 4
+    num_chunks = 3
+    U = chunk_length * num_chunks
+    L, R = 1, 2
+    D = 256
+    num_encoder_layers = 3
+    K = 3
+    memory_sizes = [0, 3]
+
+    for M in memory_sizes:
+        encoder = EmformerEncoder(
+            chunk_length=chunk_length,
+            d_model=D,
+            dim_feedforward=1024,
+            num_encoder_layers=num_encoder_layers,
+            left_context_length=L,
+            right_context_length=R,
+            max_memory_size=M,
+            dropout=0.1,
+            cnn_module_kernel=K,
+            causal=True,
+        )
+        encoder.eval()
+
+        x = torch.randn(U + R, 1, D)
+        lengths = torch.tensor([U + R])
+
+        forward_output, forward_output_lengths = encoder(x, lengths)
+
+        states = None
+        conv_caches = None
+        for chunk_idx in range(num_chunks):
+            start_idx = chunk_idx * chunk_length
+            end_idx = start_idx + chunk_length
+            chunk = x[start_idx : end_idx + R]  # noqa
+            chunk_right_context = x[end_idx : end_idx + R]  # noqa
+            chunk_length = torch.tensor([chunk_length])
+            (
+                infer_output_chunk,
+                infer_output_lengths,
+                states,
+                conv_caches,
+            ) = encoder.infer(
+                chunk,
+                chunk_length,
+                states,
+                conv_caches,
+            )
+            forward_output_chunk = forward_output[start_idx:end_idx]
+            assert torch.allclose(
+                infer_output_chunk,
+                forward_output_chunk,
+                atol=1e-5,
+                rtol=0.0,
+            )
+
+
+def test_emformer_forward_infer_consistency():
+    from emformer import Emformer
+
+    num_features = 80
+    output_dim = 1000
+    chunk_length = 8
+    num_chunks = 3
+    U = chunk_length * num_chunks
+    L, R = 128, 4
+    D = 256
+    num_encoder_layers = 2
+    K = 3
+    memory_sizes = [0, 3]
+
+    for M in memory_sizes:
+        model = Emformer(
+            num_features=num_features,
+            output_dim=output_dim,
+            chunk_length=chunk_length,
+            subsampling_factor=4,
+            d_model=D,
+            num_encoder_layers=num_encoder_layers,
+            cnn_module_kernel=K,
+            left_context_length=L,
+            right_context_length=R,
+            max_memory_size=M,
+            dropout=0.1,
+            vgg_frontend=False,
+            causal=True,
+        )
+        model.eval()
+
+        x = torch.randn(1, U + R + 3, num_features)
+        x_lens = torch.tensor([x.size(1)])
+
+        # forward mode
+        forward_logits, _ = model(x, x_lens)
+
+        states = None
+        conv_caches = None
+        for chunk_idx in range(num_chunks):
+            start_idx = chunk_idx * chunk_length
+            end_idx = start_idx + chunk_length
+            chunk = x[:, start_idx : end_idx + R + 3]  # noqa
+            lengths = torch.tensor([chunk.size(1)])
+            (
+                infer_chunk_logits,
+                output_lengths,
+                states,
+                conv_caches,
+            ) = model.infer(chunk, lengths, states, conv_caches)
+            forward_chunk_logits = forward_logits[
+                :, start_idx // 4 : end_idx // 4  # noqa
+            ]
+            assert torch.allclose(
+                infer_chunk_logits,
+                forward_chunk_logits,
+                atol=1e-5,
+                rtol=0.0,
+            )
+
+
 if __name__ == "__main__":
     test_emformer_attention_forward()
     test_emformer_attention_infer()
@@ -375,3 +585,6 @@ if __name__ == "__main__":
     test_emformer_encoder_infer()
     test_emformer_forward()
     test_emformer_infer()
+    test_emformer_encoder_layer_forward_infer_consistency()
+    test_emformer_encoder_forward_infer_consistency()
+    test_emformer_forward_infer_consistency()
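
Note for readers adapting these tests (not part of the patch): all three new tests follow the same forward-vs-streaming pattern. The whole utterance plus right context is run once through the non-streaming forward pass, then the same input is fed chunk by chunk through infer() while the attention states and convolution caches are carried across calls, and each streamed chunk output is compared against the corresponding slice of the non-streaming output with torch.allclose. The sketch below is a condensed, hypothetical variant of test_emformer_encoder_forward_infer_consistency; it assumes it is run from egs/librispeech/ASR/conv_emformer_transducer/ so that emformer.py is importable, and it reuses the constructor arguments shown in the diff above. Variable names such as start, end and infer_output are illustrative only.

    import torch

    from emformer import EmformerEncoder  # assumes CWD is conv_emformer_transducer/

    chunk_length = 4
    num_chunks = 3
    U = chunk_length * num_chunks
    R = 2  # right context length

    encoder = EmformerEncoder(
        chunk_length=chunk_length,
        d_model=256,
        dim_feedforward=1024,
        num_encoder_layers=3,
        left_context_length=1,
        right_context_length=R,
        max_memory_size=3,
        dropout=0.1,
        cnn_module_kernel=3,
        causal=True,
    )
    encoder.eval()  # eval mode, so dropout cannot break the comparison

    x = torch.randn(U + R, 1, 256)

    # Non-streaming pass over the full utterance plus right context.
    forward_output, _ = encoder(x, torch.tensor([U + R]))

    # Streaming pass: one chunk at a time, carrying the caches forward.
    states = None
    conv_caches = None
    for chunk_idx in range(num_chunks):
        start = chunk_idx * chunk_length
        end = start + chunk_length
        chunk = x[start : end + R]  # current chunk plus its right context
        infer_output, _, states, conv_caches = encoder.infer(
            chunk, torch.tensor([chunk_length]), states, conv_caches
        )
        # The streamed output must match the corresponding slice of the
        # non-streaming output.
        assert torch.allclose(
            infer_output, forward_output[start:end], atol=1e-5, rtol=0.0
        )

Running python test_emformer.py from that directory should exercise everything listed under __main__, including the three consistency checks added here.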