mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-18 21:44:18 +00:00
add test function for conv module, and minor fix.
This commit is contained in:
parent b24453434f
commit 4dc0af1144
@@ -42,15 +42,16 @@ LOG_EPSILON = math.log(1e-10)
 def unstack_states(
     states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]]
 ) -> List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]]:
-    # TODO: modify doc
     """Unstack the emformer state corresponding to a batch of utterances
     into a list of states, where the i-th entry is the state from the i-th
     utterance in the batch.

     Args:
       states:
-        A list-of-list of tensors.
-        ``len(states[0])`` and ``len(states[1])`` equal to the number of layers.
+        A list of tuples.
+        ``states[i][0]`` is the attention caches of i-th utterance.
+        ``states[i][1]`` is the convolution caches of i-th utterance.
+        ``len(states[i][0])`` and ``len(states[i][1])`` both equal to the number of layers.  # noqa
     """

     attn_caches, conv_caches = states
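For orientation: per the signature above, the batched state is a single (attn_caches, conv_caches) tuple, which this function fans out into one such tuple per utterance. A minimal sketch of the layout, using placeholder tensors (the real cache shapes are not shown in this diff):

    from typing import List, Tuple
    import torch

    num_layers = 12  # illustrative value

    attn_caches: List[List[torch.Tensor]] = [[torch.zeros(4)] for _ in range(num_layers)]
    conv_caches: List[torch.Tensor] = [torch.zeros(4) for _ in range(num_layers)]
    states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]] = (attn_caches, conv_caches)

    # unstack_states(states)[i] -> (attn_caches_i, conv_caches_i) for utterance i,
    # where len(attn_caches_i) == len(conv_caches_i) == num_layers.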
@@ -146,7 +147,7 @@ class ConvolutionModule(nn.Module):
       right_context_length (int):
         Length of right context.
       channels (int):
-        The number of channels of conv layers.
+        The number of input channels and output channels of conv layers.
       kernel_size (int):
         Kernel size of conv layers.
       bias (bool):
@@ -162,9 +163,9 @@ class ConvolutionModule(nn.Module):
         bias: bool = True,
     ) -> None:
         """Construct a ConvolutionModule object."""
-        super(ConvolutionModule, self).__init__()
-        # kernel_size should be a odd number for 'SAME' padding
-        assert (kernel_size - 1) % 2 == 0
+        super().__init__()
+        # kernel_size should be an odd number for 'SAME' padding
+        assert (kernel_size - 1) % 2 == 0, kernel_size

         self.chunk_length = chunk_length
         self.right_context_length = right_context_length
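The odd-kernel assertion exists because 'SAME' padding for a stride-1 convolution pads (kernel_size - 1) // 2 frames on each side, which preserves the time dimension exactly only when kernel_size is odd. A standalone check (illustrative, not code from the repo):

    import torch
    import torch.nn as nn

    kernel_size = 31                 # odd, so (kernel_size - 1) % 2 == 0
    pad = (kernel_size - 1) // 2     # 15 frames on each side
    conv = nn.Conv1d(256, 256, kernel_size, padding=pad)
    x = torch.randn(2, 256, 100)     # (batch, channels, time)
    assert conv(x).shape[-1] == 100  # time length is preserved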
@@ -53,7 +53,7 @@ class Stream(object):
         # Initialize zero states.
         self.init_states(params)

-        # It use different attributes for different decoding methods.
+        # It uses different attributes for different decoding methods.
         self.context_size = params.context_size
         self.decoding_method = params.decoding_method
         if params.decoding_method == "greedy_search":
@@ -72,7 +72,7 @@ class Stream(object):
             self.rnnt_decoding_stream: k2.RnntDecodingStream = (
                 k2.RnntDecodingStream(decoding_graph)
             )
-            self.hyp: List[int] = None
+            self.hyp: Optional[List[int]] = None
         else:
             raise ValueError(
                 f"Unsupported decoding method: {params.decoding_method}"
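Annotating hyp as Optional[List[int]] matches its initial value of None; the bare List[int] annotation is rejected by static type checkers. A standalone illustration:

    from typing import List, Optional

    hyp: Optional[List[int]] = None  # accepted: None is a valid Optional[List[int]]
    # hyp: List[int] = None          # flagged by checkers such as mypy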
@@ -134,10 +134,14 @@ class Stream(object):
             )
             for _ in range(params.num_encoder_layers)
         ]
-        self.states = [attn_caches, conv_caches]
+        self.states = (attn_caches, conv_caches)

-    def get_feature_chunk(self) -> Tuple[torch.Tensor, int]:
-        """Get a chunk of feature frames."""
+    def get_feature_chunk(self) -> torch.Tensor:
+        """Get a chunk of feature frames.
+
+        Returns:
+          A tensor of shape (ret_length, feature_dim).
+        """
         update_length = min(
             self.num_frames - self.num_processed_frames, self.chunk_length
         )
@@ -153,11 +157,11 @@ class Stream(object):
         if self.num_processed_frames >= self.num_frames:
             self._done = True

-        return ret_feature, ret_length
+        return ret_feature

     @property
     def done(self) -> bool:
-        """Return True if `self.input_finished()` has been invoked"""
+        """Return True if all feature frames are processed."""
         return self._done

     def decoding_result(self) -> List[int]:
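Together, the new get_feature_chunk signature and the done property support a simple chunk-by-chunk driver loop. A hedged sketch, assuming stream is an initialized Stream, with a hypothetical process_chunk helper standing in for the real decoding step:

    while not stream.done:
        feature = stream.get_feature_chunk()  # (ret_length, feature_dim)
        feature_len = feature.size(0)         # length now derived from the tensor itself
        process_chunk(feature, feature_len)   # hypothetical per-chunk decoding step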
@@ -245,8 +245,9 @@ def greedy_search(
     model: nn.Module,
     encoder_out: torch.Tensor,
     streams: List[Stream],
-) -> List[List[int]]:
+) -> None:
     """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.

     Args:
       model:
         The transducer model.
@@ -270,10 +271,9 @@ def greedy_search(
         device=device,
         dtype=torch.int64,
     )
-    # decoder_out is of shape (N, decoder_out_dim)
+    # decoder_out is of shape (batch_size, 1, decoder_out_dim)
     decoder_out = model.decoder(decoder_input, need_pad=False)
     decoder_out = model.joiner.decoder_proj(decoder_out)
-    # logging.info(f"decoder_out shape : {decoder_out.shape}")

     for t in range(T):
         # current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
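The loop that follows is a standard frame-synchronous greedy decode: each encoder frame is fused with the current decoder output by the joiner, and the decoder state is refreshed only for streams that emit a non-blank symbol. A rough sketch of the pattern (the joiner call and blank_id handling are assumptions; exact signatures vary across icefall recipes):

    for t in range(T):
        current_encoder_out = encoder_out[:, t : t + 1]    # (batch_size, 1, encoder_out_dim)
        logits = joiner(current_encoder_out, decoder_out)  # hypothetical joiner call
        y = logits.argmax(dim=-1).squeeze()                # hardcodes --max-sym-per-frame=1
        # Only streams with y != blank_id get an updated decoder state.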
@@ -427,7 +427,7 @@ def fast_beam_search_one_best(
     beam: float,
     max_states: int,
     max_contexts: int,
-) -> List[List[int]]:
+) -> None:
     """It limits the maximum number of symbols per frame to 1.

     A lattice is first obtained using modified beam search, and then
@@ -449,8 +449,6 @@ def fast_beam_search_one_best(
         Max states per stream per frame.
       max_contexts:
         Max contexts per stream per frame.
-    Returns:
-      Return the decoded result.
     """
     assert encoder_out.ndim == 3
@@ -543,7 +541,8 @@ def decode_one_chunk(
         # before calling `stream.get_feature_chunk()`
         # since `stream.num_processed_frames` would be updated
         num_processed_frames_list.append(stream.num_processed_frames)
-        feature, feature_len = stream.get_feature_chunk()
+        feature = stream.get_feature_chunk()
+        feature_len = feature.size(0)
         feature_list.append(feature)
         feature_len_list.append(feature_len)
         state_list.append(stream.states)
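After this per-stream loop, the collected chunks are typically padded into a single batch and the per-stream states merged; a hedged sketch of that step (the use of pad_sequence and stack_states here is an assumption about the surrounding code, which is not shown in this diff):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    # Pad variable-length chunks into one (num_streams, max_len, feature_dim) batch.
    features = pad_sequence(feature_list, batch_first=True, padding_value=LOG_EPSILON)
    feature_lens = torch.tensor(feature_len_list)
    states = stack_states(state_list)  # merge per-stream (attn, conv) tuples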
@@ -809,7 +808,6 @@ def main():
         "fast_beam_search",
         "modified_beam_search",
     )
-    # Note: params.decoding_method is currently not used.
     params.res_dir = params.exp_dir / "streaming" / params.decoding_method

     if params.iter > 0:
@@ -1,9 +1,64 @@
 #!/usr/bin/env python3

 import torch
-from emformer import Emformer, stack_states, unstack_states
+
+
+def test_convolution_module_forward():
+    from emformer import ConvolutionModule
+
+    B, D = 2, 256
+    chunk_length = 4
+    right_context_length = 2
+    num_chunks = 3
+    U = num_chunks * chunk_length
+    R = num_chunks * right_context_length
+    kernel_size = 31
+    conv_module = ConvolutionModule(
+        chunk_length,
+        right_context_length,
+        D,
+        kernel_size,
+    )
+
+    utterance = torch.randn(U, B, D)
+    right_context = torch.randn(R, B, D)
+
+    utterance, right_context = conv_module(utterance, right_context)
+    assert utterance.shape == (U, B, D)
+    assert right_context.shape == (R, B, D)
+
+
+def test_convolution_module_infer():
+    from emformer import ConvolutionModule
+
+    B, D = 2, 256
+    chunk_length = 4
+    right_context_length = 2
+    num_chunks = 1
+    U = num_chunks * chunk_length
+    R = num_chunks * right_context_length
+    kernel_size = 31
+    conv_module = ConvolutionModule(
+        chunk_length,
+        right_context_length,
+        D,
+        kernel_size,
+    )
+
+    utterance = torch.randn(U, B, D)
+    right_context = torch.randn(R, B, D)
+    cache = torch.randn(B, D, kernel_size - 1)
+
+    utterance, right_context, new_cache = conv_module.infer(
+        utterance, right_context, cache
+    )
+    assert utterance.shape == (U, B, D)
+    assert right_context.shape == (R, B, D)
+    assert new_cache.shape == (B, D, kernel_size - 1)
+
+
 def test_state_stack_unstack():
+    from emformer import Emformer, stack_states, unstack_states
+
     num_features = 80
     chunk_length = 32
     encoder_dim = 512
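The body of test_state_stack_unstack is truncated in this view; the property it checks is that stack_states and unstack_states are inverses. A hedged sketch of that round trip, assuming states_list is a list of per-utterance (attn_caches, conv_caches) tuples:

    batched = stack_states(states_list)  # one batched (attn_caches, conv_caches) tuple
    recovered = unstack_states(batched)  # back to a list of per-utterance tuples
    assert len(recovered) == len(states_list)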
@@ -62,8 +117,6 @@ def test_state_stack_unstack():

 def test_torchscript_consistency_infer():
     r"""Verify that scripting Emformer does not change the behavior of method `infer`."""  # noqa
+    from emformer import Emformer

     num_features = 80
     chunk_length = 32
     encoder_dim = 512
@@ -118,5 +171,7 @@ def test_torchscript_consistency_infer():


 if __name__ == "__main__":
+    test_convolution_module_forward()
+    test_convolution_module_infer()
     test_state_stack_unstack()
     test_torchscript_consistency_infer()
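With the two new entries in the __main__ block, running the file directly (e.g. python3 test_emformer.py) executes all four tests in sequence.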
@@ -619,7 +619,7 @@ def compute_loss(
     warmup: float = 1.0,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
-    Compute CTC loss given the model and its inputs.
+    Compute RNN-T loss given the model and its inputs.

     Args:
       params: