diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py index 650c2538b..99fb0a877 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py @@ -43,8 +43,8 @@ def test_convolution_module_forward(): right_context = torch.randn(R, B, D) utterance, right_context = conv_module(utterance, right_context) - assert utterance.shape == (U, B, D) - assert right_context.shape == (R, B, D) + assert utterance.shape == (U, B, D), utterance.shape + assert right_context.shape == (R, B, D), right_context.shape def test_convolution_module_infer(): @@ -71,9 +71,9 @@ def test_convolution_module_infer(): utterance, right_context, new_cache = conv_module.infer( utterance, right_context, cache ) - assert utterance.shape == (U, B, D) - assert right_context.shape == (R, B, D) - assert new_cache.shape == (B, D, kernel_size - 1) + assert utterance.shape == (U, B, D), utterance.shape + assert right_context.shape == (R, B, D), right_context.shape + assert new_cache.shape == (B, D, kernel_size - 1), new_cache.shape def test_state_stack_unstack():