From 3e131891a2a466e279b6a7492361393d7f1a2093 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Fri, 8 Apr 2022 20:43:54 +0800 Subject: [PATCH] update test_emformer.py --- .../test_emformer.py | 66 ++++++++++--------- 1 file changed, 36 insertions(+), 30 deletions(-) diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py index 4c9cbba9c..56cf2035e 100644 --- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py @@ -61,16 +61,20 @@ def test_emformer_attention_infer(): left_context_key = torch.randn(L, B, D) left_context_val = torch.randn(L, B, D) - output_right_context_utterance, output_memory, next_key, next_val = \ - attention.infer( - utterance, - lengths, - right_context, - summary, - memory, - left_context_key, - left_context_val, - ) + ( + output_right_context_utterance, + output_memory, + next_key, + next_val, + ) = attention.infer( + utterance, + lengths, + right_context, + summary, + memory, + left_context_key, + left_context_val, + ) assert output_right_context_utterance.shape == (R + U, B, D) assert output_memory.shape == (S, B, D) assert next_key.shape == (L + U, B, D) @@ -98,7 +102,7 @@ def test_emformer_layer_forward(): chunk_length=chunk_length, left_context_length=L, max_memory_size=M, - ) + ) Q, KV = R + U + S, M + R + U utterance = torch.randn(U, B, D) @@ -141,7 +145,7 @@ def test_emformer_layer_infer(): chunk_length=chunk_length, left_context_length=L, max_memory_size=M, - ) + ) utterance = torch.randn(U, B, D) lengths = torch.randint(1, U + 1, (B,)) @@ -149,14 +153,18 @@ def test_emformer_layer_infer(): right_context = torch.randn(R, B, D) memory = torch.randn(M, B, D) state = None - output_utterance, output_right_context, output_memory, output_state = \ - layer.infer( - utterance, - lengths, - right_context, - memory, - state, - ) + ( + output_utterance, + output_right_context, + output_memory, + output_state, + ) = layer.infer( + utterance, + lengths, + right_context, + memory, + state, + ) assert output_utterance.shape == (U, B, D) assert output_right_context.shape == (R, B, D) if use_memory: @@ -200,9 +208,7 @@ def test_emformer_encoder_forward(): output, output_lengths = encoder(x, lengths) assert output.shape == (U, B, D) - assert torch.equal( - output_lengths, torch.clamp(lengths - R, min=0) - ) + assert torch.equal(output_lengths, torch.clamp(lengths - R, min=0)) def test_emformer_encoder_infer(): @@ -236,8 +242,7 @@ def test_emformer_encoder_infer(): x = torch.randn(U + R, B, D) lengths = torch.randint(1, U + R + 1, (B,)) lengths[0] = U + R - output, output_lengths, states = \ - encoder.infer(x, lengths, states) + output, output_lengths, states = encoder.infer(x, lengths, states) assert output.shape == (U, B, D) assert torch.equal(output_lengths, torch.clamp(lengths - R, min=0)) assert len(states) == num_encoder_layers @@ -253,6 +258,7 @@ def test_emformer_encoder_infer(): def test_emformer_forward(): from emformer import Emformer + num_features = 80 output_dim = 1000 chunk_length = 8 @@ -281,12 +287,13 @@ def test_emformer_forward(): assert logits.shape == (B, U // 4, output_dim) assert torch.equal( output_lengths, - torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0) + torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0), ) def test_emformer_infer(): from emformer import Emformer + num_features = 80 output_dim = 1000 chunk_length = 8 @@ -317,12 +324,11 @@ def test_emformer_infer(): x = torch.randn(B, U + R + 3, num_features) x_lens = torch.randint(1, U + R + 3 + 1, (B,)) x_lens[0] = U + R + 3 - logits, output_lengths, states = \ - model.infer(x, x_lens, states) + logits, output_lengths, states = model.infer(x, x_lens, states) assert logits.shape == (B, U // 4, output_dim) assert torch.equal( output_lengths, - torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0) + torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0), ) assert len(states) == num_encoder_layers for state in states: @@ -332,7 +338,7 @@ def test_emformer_infer(): assert state[2].shape == (L // 4, B, D) assert torch.equal( state[3], - U // 4 * (chunk_idx + 1) * torch.ones_like(state[3]) + U // 4 * (chunk_idx + 1) * torch.ones_like(state[3]), )