refactor Stream class

yaozengwei 2022-06-09 13:00:22 +08:00
parent f8071e9373
commit acc8a36b5e


@@ -14,7 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional
+import math
+from typing import List, Optional, Tuple
 
 import torch
 from beam_search import HypothesisList
@@ -46,53 +47,27 @@ class Stream(object):
     def __init__(
         self,
         params: AttributeDict,
-        audio_sample: torch.Tensor,
-        ground_truth: str,
         device: torch.device = torch.device("cpu"),
     ) -> None:
         """
         Args:
-          context_size:
-            Context size of the RNN-T decoder model.
-          decoding_method:
-            Decoding method. The possible values are:
-              - greedy_search
-              - modified_beam_search
+          params:
+            It's the return value of :func:`get_params`.
+          device:
+            The device to run this stream.
         """
-        self.feature_extractor = _create_streaming_feature_extractor()
-        # It contains a list of 1-D tensors representing the feature frames.
-        self.feature_frames: List[torch.Tensor] = []
-        self.num_fetched_frames = 0
-
-        # After calling `self.input_finished()`, we set this flag to True
-        self._done = False
+        self.device = device
 
+        # Containing attention caches and convolution caches
+        self.states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]] = None
         # Initialize zero states.
-        past_len: int = 0
-        attn_caches = [
-            [
-                torch.zeros(params.memory_size, params.d_model, device=device),
-                torch.zeros(
-                    params.left_context_length, params.d_model, device=device
-                ),
-                torch.zeros(
-                    params.left_context_length, params.d_model, device=device
-                ),
-            ]
-            for _ in range(params.num_encoder_layers)
-        ]
-        conv_caches = [
-            torch.zeros(params.d_model, params.cnn_module_kernel, device=device)
-            for _ in range(params.num_encoder_layers)
-        ]
-        self.states = [past_len, attn_caches, conv_caches]
+        self.init_states(params)
 
         # It uses different attributes for different decoding methods.
         self.context_size = params.context_size
         self.decoding_method = params.decoding_method
         if params.decoding_method == "greedy_search":
-            self.hyp: Optional[List[int]] = None
-            self.decoder_out: Optional[torch.Tensor] = None
+            self.hyp = [params.blank_id] * params.context_size
         elif params.decoding_method == "modified_beam_search":
             self.hyps = HypothesisList()
         else:
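The refactor primes greedy search with `self.hyp = [params.blank_id] * params.context_size` instead of a `None` placeholder. The idea: a stateless RNN-T decoder conditions each prediction on the previous `context_size` tokens, so starting from `context_size` blanks gives it a valid input on the very first chunk, with no special case. A minimal sketch of that idea (illustrative values, not code from this commit):

    import torch

    def initial_decoder_input(blank_id: int, context_size: int) -> torch.Tensor:
        """Build the initial decoder input for streaming greedy search."""
        hyp = [blank_id] * context_size
        # Shape (1, context_size): a batch containing one partial hypothesis.
        return torch.tensor(hyp, dtype=torch.int64).unsqueeze(0)

    # With blank_id=0 and context_size=2 (typical stateless-decoder settings):
    print(initial_decoder_input(0, 2))  # tensor([[0, 0]])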
@@ -100,70 +75,86 @@ class Stream(object):
                 f"Unsupported decoding method: {params.decoding_method}"
             )
 
-        self.sample_rate = params.sample_rate
-        self.audio_sample = audio_sample
-        # Current index of sample
-        self.cur_index = 0
+        self.ground_truth: str = ""
 
-        self.ground_truth = ground_truth
+        self.feature: torch.Tensor = None
+        # Make sure all feature frames can be used.
+        # Add 2 here since we will drop the first and last after subsampling.
+        self.chunk_length = params.chunk_length
+        self.pad_length = (
+            params.right_context_length + 2 * params.subsampling_factor + 3
+        )
+        self.num_frames = 0
+        self.num_processed_frames = 0
+
+        # After all feature frames are processed, we set this flag to True
+        self._done = False
 
-    def accept_waveform(
-        self,
-        # sampling_rate: float,
-        # waveform: torch.Tensor,
-    ) -> None:
-        """Feed audio samples to the feature extractor and compute features
-        if there are enough samples available.
+    def set_feature(self, feature: torch.Tensor) -> None:
+        assert feature.dim() == 2, feature.dim()
+        self.num_frames = feature.size(0)
+        # tail padding
+        self.feature = torch.nn.functional.pad(
+            feature,
+            (0, 0, 0, self.pad_length),
+            mode="constant",
+            value=math.log(1e-10),
+        )
 
-        Caution:
-          The range of the audio samples should match the one used in the
-          training. That is, if you use the range [-1, 1] in the training, then
-          the input audio samples should also be normalized to [-1, 1].
+    def set_ground_truth(self, ground_truth: str) -> None:
+        self.ground_truth = ground_truth
 
-        Args:
-          sampling_rate:
-            The sampling rate of the input audio samples. It is used for sanity
-            check to ensure that the input sampling rate equals to the one
-            used in the extractor. If they are not equal, then no resampling
-            will be performed; instead an error will be thrown.
-          waveform:
-            A 1-D torch tensor of dtype torch.float32 containing audio samples.
-            It should be on CPU.
-        """
-        start = self.cur_index
-        end = self.cur_index + 1024
-        waveform = self.audio_sample[start:end]
-        self.cur_index = end
-        self.feature_extractor.accept_waveform(
-            sampling_rate=self.sampling_rate,
-            waveform=waveform,
-        )
-        self._fetch_frames()
+    def init_states(self, params: AttributeDict) -> None:
+        attn_caches = [
+            [
+                torch.zeros(
+                    params.memory_size, params.encoder_dim, device=self.device
+                ),
+                torch.zeros(
+                    params.left_context_length // params.subsampling_factor,
+                    params.encoder_dim,
+                    device=self.device,
+                ),
+                torch.zeros(
+                    params.left_context_length // params.subsampling_factor,
+                    params.encoder_dim,
+                    device=self.device,
+                ),
+            ]
+            for _ in range(params.num_encoder_layers)
+        ]
+        conv_caches = [
+            torch.zeros(
+                params.encoder_dim,
+                params.cnn_module_kernel - 1,
+                device=self.device,
+            )
+            for _ in range(params.num_encoder_layers)
+        ]
+        self.states = [attn_caches, conv_caches]
 
-        if waveform.numel() == 0:
-            self.input_finished()
+    def get_feature_chunk(self) -> Tuple[torch.Tensor, int]:
+        """Get a chunk of feature frames."""
+        update_length = min(
+            self.num_frames - self.num_processed_frames, self.chunk_length
+        )
+        ret_length = update_length + self.pad_length
+
+        ret_feature = self.feature[:ret_length]
+        # Cut off used frames.
+        self.feature = self.feature[update_length:]
+
+        self.num_processed_frames += update_length
+        if self.num_processed_frames >= self.num_frames:
+            self._done = True
+
+        return ret_feature, ret_length
 
-    def input_finished(self) -> None:
-        """Signal that no more audio samples available and the feature
-        extractor should flush the buffered samples to compute frames.
-        """
-        self.feature_extractor.input_finished()
-        self._fetch_frames()
-        self._done = True
-
     @property
     def done(self) -> bool:
-        """Return True if `self.input_finished()` has been invoked"""
+        """Return True if all feature frames are processed."""
         return self._done
 
-    def _fetch_frames(self) -> None:
-        """Fetch frames from the feature extractor"""
-        while self.num_fetched_frames < self.feature_extractor.num_frames_ready:
-            frame = self.feature_extractor.get_frame(self.num_fetched_frames)
-            self.feature_frames.append(frame)
-            self.num_fetched_frames += 1
-
     def decoding_result(self) -> List[int]:
         """Obtain current decoding result."""
         if self.decoding_method == "greedy_search":
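After this refactor a Stream no longer pulls raw audio through an online feature extractor; features are pushed in once with `set_feature` and consumed chunk by chunk with `get_feature_chunk`. A minimal driver sketch of the new workflow, assuming `params` comes from `get_params()` and `compute_fbank` is a hypothetical helper (neither is shown in this diff):

    import torch

    # Assumed to exist outside this diff: get_params(), and a feature matrix
    # of shape (num_frames, feature_dim) computed by some fbank frontend.
    params = get_params()
    feature = compute_fbank(wave)  # hypothetical helper, not in this diff

    stream = Stream(params, device=torch.device("cpu"))
    stream.set_ground_truth("REFERENCE TRANSCRIPT")
    stream.set_feature(feature)

    while not stream.done:
        # Each chunk carries up to `chunk_length` new frames plus `pad_length`
        # tail frames so the encoder has its right context at the boundary.
        chunk, chunk_length = stream.get_feature_chunk()
        # ... run the streaming encoder on `chunk` with `stream.states`,
        # then advance the greedy / modified beam search ...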