Repository: https://github.com/k2-fsa/icefall.git
refactor Stream class

Commit: acc8a36b5e
Parent: f8071e9373
@@ -14,7 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional
+import math
+from typing import List, Optional, Tuple
 
 import torch
 from beam_search import HypothesisList
@@ -46,53 +47,27 @@ class Stream(object):
     def __init__(
         self,
         params: AttributeDict,
-        audio_sample: torch.Tensor,
-        ground_truth: str,
-        device: torch.device = torch.devive("cpu"),
+        device: torch.device = torch.device("cpu"),
     ) -> None:
         """
         Args:
-          context_size:
-            Context size of the RNN-T decoder model.
-          decoding_method:
-            Decoding method. The possible values are:
-              - greedy_search
-              - modified_beam_search
+          params:
+            It's the return value of :func:`get_params`.
          device:
            The device to run this stream.
        """
-        self.feature_extractor = _create_streaming_feature_extractor()
-        # It contains a list of 1-D tensors representing the feature frames.
-        self.feature_frames: List[torch.Tensor] = []
-        self.num_fetched_frames = 0
-
-        # After calling `self.input_finished()`, we set this flag to True
-        self._done = False
-        past_len: int = 0
-        attn_caches = [
-            [
-                torch.zeros(params.memory_size, params.d_model, device=device),
-                torch.zeros(
-                    params.left_context_length, params.d_model, device=device
-                ),
-                torch.zeros(
-                    params.left_context_length, params.d_model, device=device
-                ),
-            ]
-            for _ in range(params.num_encoder_layers)
-        ]
-        conv_caches = [
-            torch.zeros(params.d_model, params.cnn_module_kernel, device=device)
-            for _ in range(params.num_encoder_layers)
-        ]
-        self.states = [past_len, attn_caches, conv_caches]
+        self.device = device
+
+        # Containing attention caches and convolution caches
+        self.states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]] = None
+        # Initailize zero states.
+        self.init_states()
 
         # It use different attributes for different decoding methods.
         self.context_size = params.context_size
         self.decoding_method = params.decoding_method
         if params.decoding_method == "greedy_search":
-            self.hyp: Optional[List[int]] = None
-            self.decoder_out: Optional[torch.Tensor] = None
+            self.hyp = [params.blank_id] * params.context_size
         elif params.decoding_method == "modified_beam_search":
             self.hyps = HypothesisList()
         else:
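
Aside (reviewer note, not part of the diff): after this hunk the per-stream
state is just [attn_caches, conv_caches]; the old layout carried an extra
past_len counter. A minimal sketch of that layout, with made-up sizes standing
in for the fields of `params` used above (all numeric values here are
hypothetical, not taken from the recipe):

    import torch

    # Hypothetical sizes; the real values come from `params`.
    num_encoder_layers = 12
    memory_size = 32
    left_context_length = 32  # in (possibly subsampled) frames
    encoder_dim = 512
    cnn_module_kernel = 31
    device = torch.device("cpu")

    # Per layer: one memory bank plus two left-context buffers
    # (presumably the cached attention keys and values).
    attn_caches = [
        [
            torch.zeros(memory_size, encoder_dim, device=device),
            torch.zeros(left_context_length, encoder_dim, device=device),
            torch.zeros(left_context_length, encoder_dim, device=device),
        ]
        for _ in range(num_encoder_layers)
    ]
    # Per layer: a causal-convolution buffer of kernel_size - 1 frames.
    conv_caches = [
        torch.zeros(encoder_dim, cnn_module_kernel - 1, device=device)
        for _ in range(num_encoder_layers)
    ]
    states = [attn_caches, conv_caches]  # no separate past_len counter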
@@ -100,70 +75,86 @@ class Stream(object):
                 f"Unsupported decoding method: {params.decoding_method}"
             )
 
-        self.sample_rate = params.sample_rate
-        self.audio_sample = audio_sample
-        # Current index of sample
-        self.cur_index = 0
         self.ground_truth: str = ""
 
-    def accept_waveform(
-        self,
-        # sampling_rate: float,
-        # waveform: torch.Tensor,
-    ) -> None:
-        """Feed audio samples to the feature extractor and compute features
-        if there are enough samples available.
-
-        Caution:
-          The range of the audio samples should match the one used in the
-          training. That is, if you use the range [-1, 1] in the training, then
-          the input audio samples should also be normalized to [-1, 1].
-
-        Args
-          sampling_rate:
-            The sampling rate of the input audio samples. It is used for sanity
-            check to ensure that the input sampling rate equals to the one
-            used in the extractor. If they are not equal, then no resampling
-            will be performed; instead an error will be thrown.
-          waveform:
-            A 1-D torch tensor of dtype torch.float32 containing audio samples.
-            It should be on CPU.
-        """
-        start = self.cur_index
-        end = self.cur_index + 1024
-        waveform = self.audio_sample[start:end]
-        self.cur_index = end
-
-        self.feature_extractor.accept_waveform(
-            sampling_rate=self.sampling_rate,
-            waveform=waveform,
-        )
-        self._fetch_frames()
-
-        if waveform.numel() == 0:
-            self.input_finished()
-
-    def input_finished(self) -> None:
-        """Signal that no more audio samples available and the feature
-        extractor should flush the buffered samples to compute frames.
-        """
-        self.feature_extractor.input_finished()
-        self._fetch_frames()
-        self._done = True
+        self.feature: torch.Tensor = None
+        # Make sure all feature frames can be used.
+        # Add 2 here since we will drop the first and last after subsampling.
+        self.chunk_length = params.chunk_length
+        self.pad_length = (
+            params.right_context_length + 2 * params.subsampling_factor + 3
+        )
+        self.num_frames = 0
+        self.num_processed_frames = 0
+
+        # After all feature frames are processed, we set this flag to True
+        self._done = False
+
+    def set_feature(self, feature: torch.Tensor) -> None:
+        assert feature.dim == 2, feature.dim
+        self.num_frames = feature.size(0)
+        # tail padding
+        self.feature = torch.nn.functional.pad(
+            feature,
+            (0, 0, 0, self.pad_length),
+            mode="constant",
+            value=math.log(1e-10),
+        )
+
+    def set_ground_truth(self, ground_truth: str) -> None:
+        self.ground_truth = ground_truth
+
+    def init_states(self, params: AttributeDict) -> None:
+        attn_caches = [
+            [
+                torch.zeros(
+                    params.memory_size, params.encoder_dim, device=self.device
+                ),
+                torch.zeros(
+                    params.left_context_length // params.subsampling_factor,
+                    params.encoder_dim,
+                    device=self.device,
+                ),
+                torch.zeros(
+                    params.left_context_length // params.subsampling_factor,
+                    params.encoder_dim,
+                    device=self.device,
+                ),
+            ]
+            for _ in range(params.num_encoder_layers)
+        ]
+        conv_caches = [
+            torch.zeros(
+                params.encoder_dim,
+                params.cnn_module_kernel - 1,
+                device=self.device,
+            )
+            for _ in range(params.num_encoder_layers)
+        ]
+        self.states = [attn_caches, conv_caches]
+
+    def get_feature_chunk(self) -> Tuple[torch.Tensor, int]:
+        """Get a chunk of feature frames."""
+        update_length = min(
+            self.num_frames - self.num_processed_frames, self.chunk_length
+        )
+        ret_length = update_length + self.pad_length
+
+        ret_feature = self.feature[:ret_length]
+        # Cut off used frames.
+        self.feature = self.feature[update_length:]
+
+        self.num_processed_frames += update_length
+        if self.num_processed_frames >= self.num_frames:
+            self._done = True
+
+        return ret_feature, ret_length
 
     @property
     def done(self) -> bool:
         """Return True if `self.input_finished()` has been invoked"""
         return self._done
 
-    def _fetch_frames(self) -> None:
-        """Fetch frames from the feature extractor"""
-        while self.num_fetched_frames < self.feature_extractor.num_frames_ready:
-            frame = self.feature_extractor.get_frame(self.num_fetched_frames)
-            self.feature_frames.append(frame)
-            self.num_fetched_frames += 1
-
     def decoding_result(self) -> List[int]:
         """Obtain current decoding result."""
         if self.decoding_method == "greedy_search":
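
Aside (reviewer note, not part of the diff): after the refactor a decoding loop
no longer streams raw audio into a per-stream feature extractor; it computes
the full fbank matrix up front, hands it to the stream once, and drains it
chunk by chunk. A rough driver sketch assuming the Stream class above; the
helper name drain_stream and the transcript string are illustrative, not from
the commit:

    import torch

    def drain_stream(stream, feature: torch.Tensor) -> None:
        # `feature` is a (num_frames, feature_dim) fbank matrix computed offline.
        stream.set_feature(feature)  # also applies log(1e-10) tail padding
        stream.set_ground_truth("REFERENCE TRANSCRIPT")

        while not stream.done:
            # Each call yields up to `chunk_length` new frames plus `pad_length`
            # look-ahead frames (right context + subsampling margin).
            chunk, ret_length = stream.get_feature_chunk()
            # ... run the streaming encoder on `chunk` with `stream.states`,
            # then update the hypothesis via greedy_search or
            # modified_beam_search on the encoder output ...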