refactor Stream class

yaozengwei 2022-06-09 13:00:22 +08:00
parent f8071e9373
commit acc8a36b5e


@@ -14,7 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional
+import math
+from typing import List, Optional, Tuple
 
 import torch
 from beam_search import HypothesisList
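(Here `math` is imported for the `math.log(1e-10)` tail-padding value used in `set_feature` below, and `Tuple` for the annotations on `self.states` and the new `get_feature_chunk` return type.)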
@@ -46,53 +47,27 @@ class Stream(object):
     def __init__(
         self,
         params: AttributeDict,
-        audio_sample: torch.Tensor,
-        ground_truth: str,
-        device: torch.device = torch.device("cpu"),
+        device: torch.device = torch.device("cpu"),
     ) -> None:
         """
         Args:
-          context_size:
-            Context size of the RNN-T decoder model.
-          decoding_method:
-            Decoding method. The possible values are:
-              - greedy_search
-              - modified_beam_search
+          params:
+            It's the return value of :func:`get_params`.
+          device:
+            The device to run this stream.
         """
-        self.feature_extractor = _create_streaming_feature_extractor()
-        # It contains a list of 1-D tensors representing the feature frames.
-        self.feature_frames: List[torch.Tensor] = []
-        self.num_fetched_frames = 0
-
-        # After calling `self.input_finished()`, we set this flag to True
-        self._done = False
+        self.device = device
+
+        # Contains the attention caches and convolution caches
+        self.states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]] = None
         # Initialize zero states.
-        past_len: int = 0
-        attn_caches = [
-            [
-                torch.zeros(params.memory_size, params.d_model, device=device),
-                torch.zeros(
-                    params.left_context_length, params.d_model, device=device
-                ),
-                torch.zeros(
-                    params.left_context_length, params.d_model, device=device
-                ),
-            ]
-            for _ in range(params.num_encoder_layers)
-        ]
-        conv_caches = [
-            torch.zeros(params.d_model, params.cnn_module_kernel, device=device)
-            for _ in range(params.num_encoder_layers)
-        ]
-        self.states = [past_len, attn_caches, conv_caches]
+        self.init_states(params)
 
         # It uses different attributes for different decoding methods.
         self.context_size = params.context_size
         self.decoding_method = params.decoding_method
         if params.decoding_method == "greedy_search":
-            self.hyp: Optional[List[int]] = None
-            self.decoder_out: Optional[torch.Tensor] = None
+            self.hyp = [params.blank_id] * params.context_size
         elif params.decoding_method == "modified_beam_search":
             self.hyps = HypothesisList()
         else:
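Note: after this refactor a Stream no longer pulls raw audio through a feature extractor; the caller hands it precomputed fbank features. A minimal usage sketch of the new constructor (hypothetical values; `get_params` is assumed to supply `blank_id`, `context_size`, `decoding_method`, and the encoder geometry read by `init_states`):

    import torch

    params = get_params()  # assumed helper from the same recipe
    params.decoding_method = "greedy_search"
    stream = Stream(params=params, device=torch.device("cpu"))
    # Greedy search seeds the hypothesis with `context_size` blanks,
    # e.g. blank_id=0, context_size=2 gives stream.hyp == [0, 0]; the
    # RNN-T decoder always conditions on the last `context_size` tokens.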
@@ -100,70 +75,86 @@ class Stream(object):
                 f"Unsupported decoding method: {params.decoding_method}"
             )
 
-        self.sample_rate = params.sample_rate
-        self.audio_sample = audio_sample
-        # Current index of sample
-        self.cur_index = 0
-        self.ground_truth = ground_truth
-
-    def accept_waveform(
-        self,
-        # sampling_rate: float,
-        # waveform: torch.Tensor,
-    ) -> None:
-        """Feed audio samples to the feature extractor and compute features
-        if there are enough samples available.
-
-        Caution:
-          The range of the audio samples should match the one used in
-          training. That is, if you use the range [-1, 1] in training, then
-          the input audio samples should also be normalized to [-1, 1].
-
-        Args:
-          sampling_rate:
-            The sampling rate of the input audio samples. It is used for a
-            sanity check, to ensure that the input sampling rate equals the
-            one used in the extractor. If they are not equal, no resampling
-            will be performed; instead an error will be thrown.
-          waveform:
-            A 1-D torch tensor of dtype torch.float32 containing audio samples.
-            It should be on CPU.
-        """
-        start = self.cur_index
-        end = self.cur_index + 1024
-        waveform = self.audio_sample[start:end]
-        self.cur_index = end
-        self.feature_extractor.accept_waveform(
-            sampling_rate=self.sample_rate,
-            waveform=waveform,
-        )
-        self._fetch_frames()
-        if waveform.numel() == 0:
-            self.input_finished()
-
-    def input_finished(self) -> None:
-        """Signal that no more audio samples are available and that the
-        feature extractor should flush the buffered samples to compute frames.
-        """
-        self.feature_extractor.input_finished()
-        self._fetch_frames()
-        self._done = True
+        self.ground_truth: str = ""
+
+        self.feature: torch.Tensor = None
+        # Make sure all feature frames can be used.
+        # Add 2 here since we will drop the first and last after subsampling.
+        self.chunk_length = params.chunk_length
+        self.pad_length = (
+            params.right_context_length + 2 * params.subsampling_factor + 3
+        )
+        self.num_frames = 0
+        self.num_processed_frames = 0
+
+        # After all feature frames are processed, we set this flag to True
+        self._done = False
+
+    def set_feature(self, feature: torch.Tensor) -> None:
+        assert feature.dim() == 2, feature.dim()
+        self.num_frames = feature.size(0)
+        # tail padding
+        self.feature = torch.nn.functional.pad(
+            feature,
+            (0, 0, 0, self.pad_length),
+            mode="constant",
+            value=math.log(1e-10),
+        )
+
+    def set_ground_truth(self, ground_truth: str) -> None:
+        self.ground_truth = ground_truth
+
+    def init_states(self, params: AttributeDict) -> None:
+        attn_caches = [
+            [
+                torch.zeros(
+                    params.memory_size, params.encoder_dim, device=self.device
+                ),
+                torch.zeros(
+                    params.left_context_length // params.subsampling_factor,
+                    params.encoder_dim,
+                    device=self.device,
+                ),
+                torch.zeros(
+                    params.left_context_length // params.subsampling_factor,
+                    params.encoder_dim,
+                    device=self.device,
+                ),
+            ]
+            for _ in range(params.num_encoder_layers)
+        ]
+        conv_caches = [
+            torch.zeros(
+                params.encoder_dim,
+                params.cnn_module_kernel - 1,
+                device=self.device,
+            )
+            for _ in range(params.num_encoder_layers)
+        ]
+        self.states = [attn_caches, conv_caches]
+
+    def get_feature_chunk(self) -> Tuple[torch.Tensor, int]:
+        """Get a chunk of feature frames."""
+        update_length = min(
+            self.num_frames - self.num_processed_frames, self.chunk_length
+        )
+        ret_length = update_length + self.pad_length
+        ret_feature = self.feature[:ret_length]
+        # Cut off used frames.
+        self.feature = self.feature[update_length:]
+        self.num_processed_frames += update_length
+        if self.num_processed_frames >= self.num_frames:
+            self._done = True
+
+        return ret_feature, ret_length
 
     @property
     def done(self) -> bool:
         """Return True if `self.input_finished()` has been invoked"""
         return self._done
 
-    def _fetch_frames(self) -> None:
-        """Fetch frames from the feature extractor"""
-        while self.num_fetched_frames < self.feature_extractor.num_frames_ready:
-            frame = self.feature_extractor.get_frame(self.num_fetched_frames)
-            self.feature_frames.append(frame)
-            self.num_fetched_frames += 1
-
     def decoding_result(self) -> List[int]:
         """Obtain current decoding result."""
         if self.decoding_method == "greedy_search":
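For orientation, here is a sketch of how a decode driver might consume the new API. The parameter values are purely illustrative (not taken from this diff), and `encode_chunk`/`run_search` are hypothetical stand-ins for the encoder forward pass and the search routine:

    # With e.g. right_context_length=8 and subsampling_factor=4:
    #   pad_length = 8 + 2 * 4 + 3 = 19
    # i.e. set_feature() appends 19 frames of log(1e-10) tail padding, so
    # even the last chunk keeps enough right context after subsampling.
    stream = Stream(params, device=torch.device("cpu"))
    stream.set_feature(feature)      # (num_frames, feature_dim) fbank matrix
    stream.set_ground_truth(text)
    while not stream.done:
        # Each chunk holds min(remaining, chunk_length) + pad_length frames;
        # stream.done flips to True once all num_frames frames are consumed.
        chunk, chunk_len = stream.get_feature_chunk()
        encoder_out = encode_chunk(chunk, chunk_len, stream.states)  # hypothetical
        run_search(stream, encoder_out)  # greedy_search / modified_beam_search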