add test function for conv module, and minor fix.

yaozengwei 2022-06-12 18:01:21 +08:00
parent b24453434f
commit 4dc0af1144
5 changed files with 85 additions and 27 deletions

View File

@@ -42,15 +42,16 @@ LOG_EPSILON = math.log(1e-10)
def unstack_states(
states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]]
) -> List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]]:
# TODO: modify doc
"""Unstack the emformer state corresponding to a batch of utterances
into a list of states, where the i-th entry is the state from the i-th
utterance in the batch.
Args:
states:
A list-of-list of tensors.
``len(states[0])`` and ``len(states[1])`` eqaul to number of layers.
A list of tuples.
``states[i][0]`` is the attention caches of i-th utterance.
``states[i][1]`` is the convolution caches of i-th utterance.
``len(states[i][0])`` and ``len(states[i][1])`` both equal the number of layers. # noqa
"""
attn_caches, conv_caches = states
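As a usage sketch, the two helpers invert each other. The tuple structure below follows this docstring and the test file's imports; the attention-cache shapes and the three-tensors-per-layer layout are illustrative assumptions, not code from this repo:

import torch
from emformer import stack_states, unstack_states

num_layers, batch_size, d_model = 12, 2, 512
# Stacked form: per-layer attention caches (batch dim assumed at index 1)
# and per-layer convolution caches of shape (batch, d_model, cache_len).
attn_caches = [
    [torch.zeros(4, batch_size, d_model) for _ in range(3)]
    for _ in range(num_layers)
]
conv_caches = [torch.zeros(batch_size, d_model, 30) for _ in range(num_layers)]

states_list = unstack_states((attn_caches, conv_caches))
assert len(states_list) == batch_size        # one entry per utterance
assert len(states_list[0][0]) == num_layers  # attention caches of utterance 0
assert len(states_list[0][1]) == num_layers  # convolution caches of utterance 0
states = stack_states(states_list)           # round-trip back to stacked form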
@@ -146,7 +147,7 @@ class ConvolutionModule(nn.Module):
right_context_length (int):
Length of right context.
channels (int):
The number of channels of conv layers.
The number of input channels and output channels of conv layers.
kernel_size (int):
Kernel size of conv layers.
bias (bool):
@@ -162,9 +163,9 @@
bias: bool = True,
) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
super().__init__()
# kernel_size should be an odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0, kernel_size
self.chunk_length = chunk_length
self.right_context_length = right_context_length
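The oddness assertion encodes the 'SAME' padding arithmetic: an odd kernel splits its kernel_size - 1 frames of padding evenly across both sides, so the output length matches the input length. A quick standalone check:

kernel_size = 31
padding = (kernel_size - 1) // 2  # 15 frames on each side
in_len = 100
out_len = in_len + 2 * padding - (kernel_size - 1)
assert out_len == in_len  # holds only when kernel_size is odd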

View File

@@ -53,7 +53,7 @@ class Stream(object):
# Initialize zero states.
self.init_states(params)
# It use different attributes for different decoding methods.
# It uses different attributes for different decoding methods.
self.context_size = params.context_size
self.decoding_method = params.decoding_method
if params.decoding_method == "greedy_search":
@@ -72,7 +72,7 @@ class Stream(object):
self.rnnt_decoding_stream: k2.RnntDecodingStream = (
k2.RnntDecodingStream(decoding_graph)
)
self.hyp: List[int] = None
self.hyp: Optional[List[int]] = None
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
@@ -134,10 +134,14 @@
)
for _ in range(params.num_encoder_layers)
]
self.states = [attn_caches, conv_caches]
self.states = (attn_caches, conv_caches)
def get_feature_chunk(self) -> Tuple[torch.Tensor, int]:
"""Get a chunk of feature frames."""
def get_feature_chunk(self) -> torch.Tensor:
"""Get a chunk of feature frames.
Returns:
A tensor of shape (ret_length, feature_dim).
"""
update_length = min(
self.num_frames - self.num_processed_frames, self.chunk_length
)
@@ -153,11 +157,11 @@
if self.num_processed_frames >= self.num_frames:
self._done = True
return ret_feature, ret_length
return ret_feature
@property
def done(self) -> bool:
"""Return True if `self.input_finished()` has been invoked"""
"""Return True if all feature frames are processed."""
return self._done
def decoding_result(self) -> List[int]:

View File

@@ -245,8 +245,9 @@ def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[Stream],
) -> List[List[int]]:
) -> None:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
@@ -270,10 +271,9 @@ def greedy_search(
device=device,
dtype=torch.int64,
)
# decoder_out is of shape (N, decoder_out_dim)
# decoder_out is of shape (batch_size, 1, decoder_out_dim)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# logging.info(f"decoder_out shape : {decoder_out.shape}")
for t in range(T):
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
@@ -427,7 +427,7 @@ def fast_beam_search_one_best(
beam: float,
max_states: int,
max_contexts: int,
) -> List[List[int]]:
) -> None:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using modified beam search, and then
@@ -449,8 +449,6 @@
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
Returns:
Return the decoded result.
"""
assert encoder_out.ndim == 3
@@ -543,7 +541,8 @@ def decode_one_chunk(
# before calling `stream.get_feature_chunk()`
# since `stream.num_processed_frames` would be updated
num_processed_frames_list.append(stream.num_processed_frames)
feature, feature_len = stream.get_feature_chunk()
feature = stream.get_feature_chunk()
feature_len = feature.size(0)
feature_list.append(feature)
feature_len_list.append(feature_len)
state_list.append(stream.states)
@@ -809,7 +808,6 @@ def main():
"fast_beam_search",
"modified_beam_search",
)
# Note: params.decoding_method is currently not used.
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:

View File

@@ -1,9 +1,64 @@
#!/usr/bin/env python3
import torch
from emformer import Emformer, stack_states, unstack_states
def test_convolution_module_forward():
from emformer import ConvolutionModule
B, D = 2, 256
chunk_length = 4
right_context_length = 2
num_chunks = 3
U = num_chunks * chunk_length
R = num_chunks * right_context_length
kernel_size = 31
conv_module = ConvolutionModule(
chunk_length,
right_context_length,
D,
kernel_size,
)
utterance = torch.randn(U, B, D)
right_context = torch.randn(R, B, D)
utterance, right_context = conv_module(utterance, right_context)
assert utterance.shape == (U, B, D)
assert right_context.shape == (R, B, D)
def test_convolution_module_infer():
from emformer import ConvolutionModule
B, D = 2, 256
chunk_length = 4
right_context_length = 2
num_chunks = 1
U = num_chunks * chunk_length
R = num_chunks * right_context_length
kernel_size = 31
conv_module = ConvolutionModule(
chunk_length,
right_context_length,
D,
kernel_size,
)
utterance = torch.randn(U, B, D)
right_context = torch.randn(R, B, D)
cache = torch.randn(B, D, kernel_size - 1)
utterance, right_context, new_cache = conv_module.infer(
utterance, right_context, cache
)
assert utterance.shape == (U, B, D)
assert right_context.shape == (R, B, D)
assert new_cache.shape == (B, D, kernel_size - 1)
def test_state_stack_unstack():
from emformer import Emformer, stack_states, unstack_states
num_features = 80
chunk_length = 32
encoder_dim = 512
@@ -62,8 +117,6 @@ def test_state_stack_unstack():
def test_torchscript_consistency_infer():
r"""Verify that scripting Emformer does not change the behavior of method `infer`.""" # noqa
from emformer import Emformer
num_features = 80
chunk_length = 32
encoder_dim = 512
@@ -118,5 +171,7 @@ def test_torchscript_consistency_infer():
if __name__ == "__main__":
test_convolution_module_forward()
test_convolution_module_infer()
test_state_stack_unstack()
test_torchscript_consistency_infer()
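The (B, D, kernel_size - 1) cache shape exercised in test_convolution_module_infer is exactly what a causal depthwise convolution needs to carry between chunks. A self-contained sketch of that general mechanism (an illustration of the technique, not code from this repository):

import torch
import torch.nn as nn
import torch.nn.functional as F

B, D, kernel_size = 2, 8, 5
conv = nn.Conv1d(D, D, kernel_size, groups=D, bias=False)

x = torch.randn(B, D, 12)  # full utterance: (batch, channels, time)
full = conv(F.pad(x, (kernel_size - 1, 0)))  # causal left padding

# Streamed: process two chunks, carrying the last kernel_size - 1 frames.
cache = torch.zeros(B, D, kernel_size - 1)
outs = []
for chunk in x.split(6, dim=2):
    inp = torch.cat([cache, chunk], dim=2)
    outs.append(conv(inp))
    cache = inp[:, :, -(kernel_size - 1):]
streamed = torch.cat(outs, dim=2)

assert torch.allclose(full, streamed, atol=1e-6)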

View File

@@ -619,7 +619,7 @@ def compute_loss(
warmup: float = 1.0,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
Compute RNN-T loss given the model and its inputs.
Args:
params: