minor change of test_emformer.py

This commit is contained in:
yaozengwei 2022-06-12 18:11:36 +08:00
parent e00ad8104e
commit adcbb4076d

View File

@ -43,8 +43,8 @@ def test_convolution_module_forward():
right_context = torch.randn(R, B, D)
utterance, right_context = conv_module(utterance, right_context)
assert utterance.shape == (U, B, D)
assert right_context.shape == (R, B, D)
assert utterance.shape == (U, B, D), utterance.shape
assert right_context.shape == (R, B, D), right_context.shape
def test_convolution_module_infer():
@ -71,9 +71,9 @@ def test_convolution_module_infer():
utterance, right_context, new_cache = conv_module.infer(
utterance, right_context, cache
)
assert utterance.shape == (U, B, D)
assert right_context.shape == (R, B, D)
assert new_cache.shape == (B, D, kernel_size - 1)
assert utterance.shape == (U, B, D), utterance.shape
assert right_context.shape == (R, B, D), right_context.shape
assert new_cache.shape == (B, D, kernel_size - 1), new_cache.shape
def test_state_stack_unstack():