add test function for conv module, and minor fix.

Commit 4dc0af1144, parent b24453434f
https://github.com/k2-fsa/icefall.git
@@ -42,15 +42,16 @@ LOG_EPSILON = math.log(1e-10)
 def unstack_states(
     states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]]
 ) -> List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]]:
-    # TODO: modify doc
     """Unstack the emformer state corresponding to a batch of utterances
     into a list of states, where the i-th entry is the state from the i-th
     utterance in the batch.

     Args:
       states:
-        A list-of-list of tensors.
-        ``len(states[0])`` and ``len(states[1])`` equal the number of layers.
+        A list of tuples.
+        ``states[i][0]`` is the attention caches of the i-th utterance.
+        ``states[i][1]`` is the convolution caches of the i-th utterance.
+        ``len(states[i][0])`` and ``len(states[i][1])`` both equal the number of layers.  # noqa
     """

     attn_caches, conv_caches = states
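Note: a minimal sketch of the two layouts this function converts between, with made-up layer counts and tensor shapes (the real cache shapes depend on the model configuration; the `attn_caches` entries below are purely illustrative):

import torch

num_layers, batch, channels, kernel_size = 2, 3, 8, 31

# Stacked form (the argument): a tuple (attn_caches, conv_caches), indexed
# first by layer, with every tensor batched over all utterances.
attn_caches = [
    [torch.zeros(4, batch, channels) for _ in range(3)]
    for _ in range(num_layers)
]
conv_caches = [
    torch.zeros(batch, channels, kernel_size - 1) for _ in range(num_layers)
]
stacked = (attn_caches, conv_caches)

# Unstacked form (the return value): one (attn, conv) tuple per utterance,
# with len(states[i][0]) == len(states[i][1]) == num_layers.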
@@ -146,7 +147,7 @@ class ConvolutionModule(nn.Module):
       right_context_length (int):
         Length of right context.
       channels (int):
-        The number of channels of conv layers.
+        The number of input channels and output channels of conv layers.
       kernel_size (int):
         Kernel size of conv layers.
       bias (bool):
@@ -162,9 +163,9 @@ class ConvolutionModule(nn.Module):
         bias: bool = True,
     ) -> None:
         """Construct a ConvolutionModule object."""
-        super(ConvolutionModule, self).__init__()
-        # kernel_size should be a odd number for 'SAME' padding
-        assert (kernel_size - 1) % 2 == 0
+        super().__init__()
+        # kernel_size should be an odd number for 'SAME' padding
+        assert (kernel_size - 1) % 2 == 0, kernel_size

         self.chunk_length = chunk_length
         self.right_context_length = right_context_length
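Note: the assertion backs the 'SAME' padding comment. A quick standalone check (not repo code) of why the kernel size must be odd:

import torch
import torch.nn as nn

kernel_size = 31  # odd, so (kernel_size - 1) % 2 == 0
# With stride 1 and padding = (kernel_size - 1) // 2, a conv preserves the
# time dimension exactly only for odd kernel sizes.
conv = nn.Conv1d(8, 8, kernel_size, padding=(kernel_size - 1) // 2, groups=8)
x = torch.randn(2, 8, 100)  # (batch, channels, time)
assert conv(x).shape == x.shape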
@@ -53,7 +53,7 @@ class Stream(object):
         # Initialize zero states.
         self.init_states(params)

-        # It use different attributes for different decoding methods.
+        # It uses different attributes for different decoding methods.
         self.context_size = params.context_size
         self.decoding_method = params.decoding_method
         if params.decoding_method == "greedy_search":
@@ -72,7 +72,7 @@ class Stream(object):
             self.rnnt_decoding_stream: k2.RnntDecodingStream = (
                 k2.RnntDecodingStream(decoding_graph)
             )
-            self.hyp: List[int] = None
+            self.hyp: Optional[List[int]] = None
         else:
             raise ValueError(
                 f"Unsupported decoding method: {params.decoding_method}"
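Note: the `Optional` annotation is a real fix, not cosmetics: under PEP 484 a `None` value does not satisfy `List[int]`, so strict checkers such as mypy reject the old line. A two-line illustration (hypothetical snippet, not repo code):

from typing import List, Optional

hyp: Optional[List[int]] = None  # OK: no hypothesis until decoding starts
# hyp: List[int] = None  # rejected by mypy: None is not a List[int]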
@@ -134,10 +134,14 @@ class Stream(object):
             )
             for _ in range(params.num_encoder_layers)
         ]
-        self.states = [attn_caches, conv_caches]
+        self.states = (attn_caches, conv_caches)

-    def get_feature_chunk(self) -> Tuple[torch.Tensor, int]:
-        """Get a chunk of feature frames."""
+    def get_feature_chunk(self) -> torch.Tensor:
+        """Get a chunk of feature frames.
+
+        Returns:
+          A tensor of shape (ret_length, feature_dim).
+        """
         update_length = min(
             self.num_frames - self.num_processed_frames, self.chunk_length
         )
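Note: switching `self.states` from a list to a tuple makes the runtime value match the `Tuple[List[List[torch.Tensor]], List[torch.Tensor]]` annotation that `unstack_states` declares above.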
@@ -153,11 +157,11 @@ class Stream(object):
         if self.num_processed_frames >= self.num_frames:
             self._done = True

-        return ret_feature, ret_length
+        return ret_feature

     @property
     def done(self) -> bool:
-        """Return True if `self.input_finished()` has been invoked"""
+        """Return True if all feature frames are processed."""
         return self._done

     def decoding_result(self) -> List[int]:
@@ -245,8 +245,9 @@ def greedy_search(
     model: nn.Module,
     encoder_out: torch.Tensor,
     streams: List[Stream],
-) -> List[List[int]]:
+) -> None:
     """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.

     Args:
       model:
         The transducer model.
@@ -270,10 +271,9 @@ def greedy_search(
         device=device,
         dtype=torch.int64,
     )
-    # decoder_out is of shape (N, decoder_out_dim)
+    # decoder_out is of shape (batch_size, 1, decoder_out_dim)
     decoder_out = model.decoder(decoder_input, need_pad=False)
     decoder_out = model.joiner.decoder_proj(decoder_out)
-    # logging.info(f"decoder_out shape : {decoder_out.shape}")

     for t in range(T):
         # current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
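Note: the corrected comment records that the stateless decoder keeps a time axis of length 1. A toy stand-in (assuming an Embedding followed by a Conv1d over a fixed left context, as in icefall's stateless decoder; names and sizes are illustrative) shows where that middle dimension comes from:

import torch
import torch.nn as nn

vocab, embed_dim, context_size, batch_size = 500, 512, 2, 4
embedding = nn.Embedding(vocab, embed_dim)
conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=context_size)

decoder_input = torch.randint(0, vocab, (batch_size, context_size))
x = embedding(decoder_input).permute(0, 2, 1)  # (batch_size, embed_dim, context_size)
decoder_out = conv(x).permute(0, 2, 1)  # (batch_size, 1, embed_dim)
assert decoder_out.shape == (batch_size, 1, embed_dim)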
@@ -427,7 +427,7 @@ def fast_beam_search_one_best(
     beam: float,
     max_states: int,
     max_contexts: int,
-) -> List[List[int]]:
+) -> None:
     """It limits the maximum number of symbols per frame to 1.

     A lattice is first obtained using modified beam search, and then
@@ -449,8 +449,6 @@ def fast_beam_search_one_best(
        Max states per stream per frame.
      max_contexts:
        Max contexts per stream per frame.
-    Returns:
-      Return the decoded result.
     """
     assert encoder_out.ndim == 3

@@ -543,7 +541,8 @@ def decode_one_chunk(
         # before calling `stream.get_feature_chunk()`
         # since `stream.num_processed_frames` would be updated
         num_processed_frames_list.append(stream.num_processed_frames)
-        feature, feature_len = stream.get_feature_chunk()
+        feature = stream.get_feature_chunk()
+        feature_len = feature.size(0)
         feature_list.append(feature)
         feature_len_list.append(feature_len)
         state_list.append(stream.states)
@@ -809,7 +808,6 @@ def main():
         "fast_beam_search",
         "modified_beam_search",
     )
-    # Note: params.decoding_method is currently not used.
     params.res_dir = params.exp_dir / "streaming" / params.decoding_method

     if params.iter > 0:
@@ -1,9 +1,64 @@
+#!/usr/bin/env python3
+
 import torch
+from emformer import Emformer, stack_states, unstack_states


+def test_convolution_module_forward():
+    from emformer import ConvolutionModule
+
+    B, D = 2, 256
+    chunk_length = 4
+    right_context_length = 2
+    num_chunks = 3
+    U = num_chunks * chunk_length
+    R = num_chunks * right_context_length
+    kernel_size = 31
+    conv_module = ConvolutionModule(
+        chunk_length,
+        right_context_length,
+        D,
+        kernel_size,
+    )
+
+    utterance = torch.randn(U, B, D)
+    right_context = torch.randn(R, B, D)
+
+    utterance, right_context = conv_module(utterance, right_context)
+    assert utterance.shape == (U, B, D)
+    assert right_context.shape == (R, B, D)
+
+
+def test_convolution_module_infer():
+    from emformer import ConvolutionModule
+
+    B, D = 2, 256
+    chunk_length = 4
+    right_context_length = 2
+    num_chunks = 1
+    U = num_chunks * chunk_length
+    R = num_chunks * right_context_length
+    kernel_size = 31
+    conv_module = ConvolutionModule(
+        chunk_length,
+        right_context_length,
+        D,
+        kernel_size,
+    )
+
+    utterance = torch.randn(U, B, D)
+    right_context = torch.randn(R, B, D)
+    cache = torch.randn(B, D, kernel_size - 1)
+
+    utterance, right_context, new_cache = conv_module.infer(
+        utterance, right_context, cache
+    )
+    assert utterance.shape == (U, B, D)
+    assert right_context.shape == (R, B, D)
+    assert new_cache.shape == (B, D, kernel_size - 1)
+
+
 def test_state_stack_unstack():
-    from emformer import Emformer, stack_states, unstack_states
-
     num_features = 80
     chunk_length = 32
     encoder_dim = 512
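Note: the new `infer` test feeds a cache of shape `(B, D, kernel_size - 1)`. That length is exactly what a causal conv must carry between chunks so that the first frame of a new chunk sees a full receptive field; a minimal sketch of the bookkeeping (not the module's actual code):

import torch

B, D, kernel_size, chunk = 2, 256, 31, 4
cache = torch.zeros(B, D, kernel_size - 1)  # last frames of the previous chunk
new_chunk = torch.randn(B, D, chunk)
padded = torch.cat([cache, new_chunk], dim=2)  # (B, D, kernel_size - 1 + chunk)
cache = padded[:, :, -(kernel_size - 1):]  # carried over to the next call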
@@ -62,8 +117,6 @@ def test_state_stack_unstack():

 def test_torchscript_consistency_infer():
     r"""Verify that scripting Emformer does not change the behavior of method `infer`."""  # noqa
-    from emformer import Emformer
-
     num_features = 80
     chunk_length = 32
     encoder_dim = 512
@@ -118,5 +171,7 @@ def test_torchscript_consistency_infer():


 if __name__ == "__main__":
+    test_convolution_module_forward()
+    test_convolution_module_infer()
     test_state_stack_unstack()
     test_torchscript_consistency_infer()
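Note: with the two new calls added to the `__main__` block, the whole file runs end to end as a plain script (e.g. `python test_emformer.py`, assuming that is the file name), and the individual functions can also be collected by pytest.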
@@ -619,7 +619,7 @@ def compute_loss(
     warmup: float = 1.0,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
-    Compute CTC loss given the model and its inputs.
+    Compute RNN-T loss given the model and its inputs.

     Args:
       params: