refactor Stream class

yaozengwei 2022-06-09 13:00:22 +08:00
parent f8071e9373
commit acc8a36b5e


@@ -14,7 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
import math
from typing import List, Optional, Tuple
import torch
from beam_search import HypothesisList
@@ -46,53 +47,27 @@ class Stream(object):
def __init__(
self,
params: AttributeDict,
audio_sample: torch.Tensor,
ground_truth: str,
device: torch.device = torch.devive("cpu"),
device: torch.device = torch.device("cpu"),
) -> None:
"""
Args:
context_size:
Context size of the RNN-T decoder model.
decoding_method:
Decoding method. The possible values are:
- greedy_search
- modified_beam_search
params:
It's the return value of :func:`get_params`.
device:
The device to run this stream.
"""
self.feature_extractor = _create_streaming_feature_extractor()
# It contains a list of 1-D tensors representing the feature frames.
self.feature_frames: List[torch.Tensor] = []
self.num_fetched_frames = 0
# After calling `self.input_finished()`, we set this flag to True
self._done = False
self.device = device
# Containing attention caches and convolution caches
self.states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]] = None
# Initialize zero states.
past_len: int = 0
attn_caches = [
[
torch.zeros(params.memory_size, params.d_model, device=device),
torch.zeros(
params.left_context_length, params.d_model, device=device
),
torch.zeros(
params.left_context_length, params.d_model, device=device
),
]
for _ in range(params.num_encoder_layers)
]
conv_caches = [
torch.zeros(params.d_model, params.cnn_module_kernel, device=device)
for _ in range(params.num_encoder_layers)
]
self.states = [past_len, attn_caches, conv_caches]
self.init_states(params)
# It uses different attributes for different decoding methods.
self.context_size = params.context_size
self.decoding_method = params.decoding_method
if params.decoding_method == "greedy_search":
self.hyp: Optional[List[int]] = None
self.decoder_out: Optional[torch.Tensor] = None
self.hyp = [params.blank_id] * params.context_size
elif params.decoding_method == "modified_beam_search":
self.hyps = HypothesisList()
else:
@@ -100,70 +75,86 @@ class Stream(object):
f"Unsupported decoding method: {params.decoding_method}"
)
self.sample_rate = params.sample_rate
self.audio_sample = audio_sample
# Current index of sample
self.cur_index = 0
self.ground_truth: str = ""
self.feature: torch.Tensor = None
# Make sure all feature frames can be used.
# Add 2 * subsampling_factor here since we will drop the first and last frames after subsampling.
self.chunk_length = params.chunk_length
self.pad_length = (
params.right_context_length + 2 * params.subsampling_factor + 3
)
self.num_frames = 0
self.num_processed_frames = 0
# After all feature frames are processed, we set this flag to True
self._done = False
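As a worked example of the pad_length computation above (the concrete values are assumed for illustration only, not taken from this commit):

# Assumed example: right_context_length = 8, subsampling_factor = 4.
right_context_length = 8
subsampling_factor = 4
pad_length = right_context_length + 2 * subsampling_factor + 3  # 8 + 8 + 3 = 19
# The 2 * subsampling_factor term covers the first and last output frames
# dropped after subsampling; each corresponds to subsampling_factor input frames.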
def set_feature(self, feature: torch.Tensor) -> None:
assert feature.dim() == 2, feature.dim()
self.num_frames = feature.size(0)
# tail padding
self.feature = torch.nn.functional.pad(
feature,
(0, 0, 0, self.pad_length),
mode="constant",
value=math.log(1e-10),
)
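A minimal sketch of what set_feature does to the time axis (the 80-dim features and the pad_length value are assumptions for illustration, not part of this commit):

# Hypothetical illustration of the tail padding.
import math
import torch

feature = torch.randn(100, 80)  # (num_frames, feature_dim)
pad_length = 19  # see the computation of self.pad_length above
padded = torch.nn.functional.pad(feature, (0, 0, 0, pad_length), mode="constant", value=math.log(1e-10))
# Only the time axis grows; the padded frames hold log(1e-10), i.e. near silence in log-fbank space.
assert padded.shape == (100 + pad_length, 80)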
def set_ground_truth(self, ground_truth: str) -> None:
self.ground_truth = ground_truth
def accept_waveform(
self,
# sampling_rate: float,
# waveform: torch.Tensor,
) -> None:
"""Feed audio samples to the feature extractor and compute features
if there are enough samples available.
def init_states(self, params: AttributeDict) -> None:
attn_caches = [
[
torch.zeros(
params.memory_size, params.encoder_dim, device=self.device
),
torch.zeros(
params.left_context_length // params.subsampling_factor,
params.encoder_dim,
device=self.device,
),
torch.zeros(
params.left_context_length // params.subsampling_factor,
params.encoder_dim,
device=self.device,
),
]
for _ in range(params.num_encoder_layers)
]
conv_caches = [
torch.zeros(
params.encoder_dim,
params.cnn_module_kernel - 1,
device=self.device,
)
for _ in range(params.num_encoder_layers)
]
self.states = [attn_caches, conv_caches]
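For reference, a sketch of the per-layer cache shapes allocated by init_states; the numeric values are assumed examples, and the key/value naming is an interpretation of the Emformer attention caches rather than something stated in this commit:

# Assumed example values: memory_size=32, left_context_length=64,
# subsampling_factor=4, encoder_dim=512, cnn_module_kernel=31.
import torch

memory_cache = torch.zeros(32, 512)  # (memory_size, encoder_dim)
left_key_cache = torch.zeros(64 // 4, 512)  # (left_context_length // subsampling_factor, encoder_dim)
left_val_cache = torch.zeros(64 // 4, 512)  # same shape as the key cache
conv_cache = torch.zeros(512, 31 - 1)  # (encoder_dim, cnn_module_kernel - 1)
# self.states = [attn_caches, conv_caches]: attn_caches holds one such triple
# per encoder layer, conv_caches holds one tensor per layer.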
Caution:
The range of the audio samples should match the one used in the
training. That is, if you use the range [-1, 1] in the training, then
the input audio samples should also be normalized to [-1, 1].
Args:
sampling_rate:
The sampling rate of the input audio samples. It is used as a sanity
check to ensure that the input sampling rate equals the one
used in the extractor. If they are not equal, no resampling
will be performed; instead an error will be thrown.
waveform:
A 1-D torch tensor of dtype torch.float32 containing audio samples.
It should be on CPU.
"""
start = self.cur_index
end = self.cur_index + 1024
waveform = self.audio_sample[start:end]
self.cur_index = end
self.feature_extractor.accept_waveform(
sampling_rate=self.sample_rate,
waveform=waveform,
def get_feature_chunk(self) -> Tuple[torch.Tensor, int]:
"""Get a chunk of feature frames."""
update_length = min(
self.num_frames - self.num_processed_frames, self.chunk_length
)
self._fetch_frames()
ret_length = update_length + self.pad_length
if waveform.numel() == 0:
self.input_finished()
ret_feature = self.feature[:ret_length]
# Cut off used frames.
self.feature = self.feature[update_length:]
def input_finished(self) -> None:
"""Signal that no more audio samples available and the feature
extractor should flush the buffered samples to compute frames.
"""
self.feature_extractor.input_finished()
self._fetch_frames()
self._done = True
self.num_processed_frames += update_length
if self.num_processed_frames >= self.num_frames:
self._done = True
return ret_feature, ret_length
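A possible calling pattern for the refactored interface (hypothetical driver code; stream, feature, text, and decode_one_chunk are assumed names, not part of this commit):

# Hypothetical usage sketch of the refactored Stream.
stream.set_feature(feature)  # pre-computed fbank features, shape (T, feature_dim)
stream.set_ground_truth(text)
while not stream.done:
    chunk, chunk_len = stream.get_feature_chunk()
    # chunk contains up to chunk_length new frames plus pad_length look-ahead frames;
    # a decoding step, e.g. decode_one_chunk(model, stream, chunk), would consume it here.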
@property
def done(self) -> bool:
"""Return True if `self.input_finished()` has been invoked"""
return self._done
def _fetch_frames(self) -> None:
"""Fetch frames from the feature extractor"""
while self.num_fetched_frames < self.feature_extractor.num_frames_ready:
frame = self.feature_extractor.get_frame(self.num_fetched_frames)
self.feature_frames.append(frame)
self.num_fetched_frames += 1
def decoding_result(self) -> List[int]:
"""Obtain current decoding result."""
if self.decoding_method == "greedy_search":