mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
minor change of test_emformer.py
This commit is contained in:
parent
e00ad8104e
commit
adcbb4076d
@ -43,8 +43,8 @@ def test_convolution_module_forward():
|
||||
right_context = torch.randn(R, B, D)
|
||||
|
||||
utterance, right_context = conv_module(utterance, right_context)
|
||||
assert utterance.shape == (U, B, D)
|
||||
assert right_context.shape == (R, B, D)
|
||||
assert utterance.shape == (U, B, D), utterance.shape
|
||||
assert right_context.shape == (R, B, D), right_context.shape
|
||||
|
||||
|
||||
def test_convolution_module_infer():
|
||||
@ -71,9 +71,9 @@ def test_convolution_module_infer():
|
||||
utterance, right_context, new_cache = conv_module.infer(
|
||||
utterance, right_context, cache
|
||||
)
|
||||
assert utterance.shape == (U, B, D)
|
||||
assert right_context.shape == (R, B, D)
|
||||
assert new_cache.shape == (B, D, kernel_size - 1)
|
||||
assert utterance.shape == (U, B, D), utterance.shape
|
||||
assert right_context.shape == (R, B, D), right_context.shape
|
||||
assert new_cache.shape == (B, D, kernel_size - 1), new_cache.shape
|
||||
|
||||
|
||||
def test_state_stack_unstack():
|
||||
|
Loading…
x
Reference in New Issue
Block a user