mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-10 09:34:39 +00:00

refactor Stream class

parent f8071e9373
commit acc8a36b5e

@@ -14,7 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional
+import math
+from typing import List, Optional, Tuple
 
 import torch
 
 from beam_search import HypothesisList

@@ -46,53 +47,27 @@ class Stream(object):
     def __init__(
         self,
         params: AttributeDict,
-        audio_sample: torch.Tensor,
-        ground_truth: str,
         device: torch.device = torch.device("cpu"),
     ) -> None:
         """
         Args:
-          context_size:
-            Context size of the RNN-T decoder model.
-          decoding_method:
-            Decoding method. The possible values are:
-              - greedy_search
-              - modified_beam_search
+          params:
+            It's the return value of :func:`get_params`.
+          device:
+            The device to run this stream.
         """
-        self.feature_extractor = _create_streaming_feature_extractor()
-        # It contains a list of 1-D tensors representing the feature frames.
-        self.feature_frames: List[torch.Tensor] = []
-        self.num_fetched_frames = 0
-
-        # After calling `self.input_finished()`, we set this flag to True
-        self._done = False
+        self.device = device
 
+        # Containing attention caches and convolution caches
+        self.states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]] = None
         # Initialize zero states.
-        past_len: int = 0
-        attn_caches = [
-            [
-                torch.zeros(params.memory_size, params.d_model, device=device),
-                torch.zeros(
-                    params.left_context_length, params.d_model, device=device
-                ),
-                torch.zeros(
-                    params.left_context_length, params.d_model, device=device
-                ),
-            ]
-            for _ in range(params.num_encoder_layers)
-        ]
-        conv_caches = [
-            torch.zeros(params.d_model, params.cnn_module_kernel, device=device)
-            for _ in range(params.num_encoder_layers)
-        ]
-        self.states = [past_len, attn_caches, conv_caches]
+        self.init_states(params)
 
         # It uses different attributes for different decoding methods.
         self.context_size = params.context_size
         self.decoding_method = params.decoding_method
         if params.decoding_method == "greedy_search":
-            self.hyp: Optional[List[int]] = None
-            self.decoder_out: Optional[torch.Tensor] = None
+            self.hyp = [params.blank_id] * params.context_size
         elif params.decoding_method == "modified_beam_search":
             self.hyps = HypothesisList()
         else:
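
One behavioral change in the hunk above is worth spelling out: with greedy_search, the hypothesis now starts as context_size copies of the blank ID rather than None, so the stateless RNN-T decoder sees a full context window from the very first frame. A minimal sketch of that priming, with made-up values (in the real script blank_id and context_size come from get_params()):

# Illustrative values; not taken from the commit.
blank_id = 0
context_size = 2

# Prime the hypothesis with blanks, as the new __init__ does.
hyp = [blank_id] * context_size  # -> [0, 0]

# During greedy search the decoder consumes hyp[-context_size:] at each
# step and emitted tokens are appended, so the decoded tokens can later
# be read off after the primed prefix.
hyp.append(27)  # pretend the model emitted token id 27
assert hyp[context_size:] == [27]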

@@ -100,70 +75,86 @@ class Stream(object):
                 f"Unsupported decoding method: {params.decoding_method}"
             )
 
-        self.sample_rate = params.sample_rate
-        self.audio_sample = audio_sample
-        # Current index of sample
-        self.cur_index = 0
-
-        self.ground_truth = ground_truth
+        self.ground_truth: str = ""
+
+        self.feature: torch.Tensor = None
+        # Make sure all feature frames can be used.
+        # Add 2 here since we will drop the first and last after subsampling.
+        self.chunk_length = params.chunk_length
+        self.pad_length = (
+            params.right_context_length + 2 * params.subsampling_factor + 3
+        )
+        self.num_frames = 0
+        self.num_processed_frames = 0
+
+        # After all feature frames are processed, we set this flag to True
+        self._done = False
 
-    def accept_waveform(
-        self,
-        # sampling_rate: float,
-        # waveform: torch.Tensor,
-    ) -> None:
-        """Feed audio samples to the feature extractor and compute features
-        if there are enough samples available.
-
-        Caution:
-          The range of the audio samples should match the one used in
-          training. That is, if you use the range [-1, 1] in training, then
-          the input audio samples should also be normalized to [-1, 1].
-
-        Args:
-          sampling_rate:
-            The sampling rate of the input audio samples. It is used for a
-            sanity check to ensure that the input sampling rate equals the
-            one used in the extractor. If they are not equal, no resampling
-            will be performed; instead an error will be thrown.
-          waveform:
-            A 1-D torch tensor of dtype torch.float32 containing audio
-            samples. It should be on CPU.
-        """
-        start = self.cur_index
-        end = self.cur_index + 1024
-        waveform = self.audio_sample[start:end]
-        self.cur_index = end
-
-        self.feature_extractor.accept_waveform(
-            sampling_rate=self.sample_rate,
-            waveform=waveform,
-        )
-        self._fetch_frames()
-
-        if waveform.numel() == 0:
-            self.input_finished()
-
-    def input_finished(self) -> None:
-        """Signal that no more audio samples are available and the feature
-        extractor should flush the buffered samples to compute frames.
-        """
-        self.feature_extractor.input_finished()
-        self._fetch_frames()
-        self._done = True
+    def set_feature(self, feature: torch.Tensor) -> None:
+        assert feature.dim() == 2, feature.dim()
+        self.num_frames = feature.size(0)
+        # tail padding
+        self.feature = torch.nn.functional.pad(
+            feature,
+            (0, 0, 0, self.pad_length),
+            mode="constant",
+            value=math.log(1e-10),
+        )
+
+    def set_ground_truth(self, ground_truth: str) -> None:
+        self.ground_truth = ground_truth
+
+    def init_states(self, params: AttributeDict) -> None:
+        attn_caches = [
+            [
+                torch.zeros(
+                    params.memory_size, params.encoder_dim, device=self.device
+                ),
+                torch.zeros(
+                    params.left_context_length // params.subsampling_factor,
+                    params.encoder_dim,
+                    device=self.device,
+                ),
+                torch.zeros(
+                    params.left_context_length // params.subsampling_factor,
+                    params.encoder_dim,
+                    device=self.device,
+                ),
+            ]
+            for _ in range(params.num_encoder_layers)
+        ]
+        conv_caches = [
+            torch.zeros(
+                params.encoder_dim,
+                params.cnn_module_kernel - 1,
+                device=self.device,
+            )
+            for _ in range(params.num_encoder_layers)
+        ]
+        self.states = [attn_caches, conv_caches]
+
+    def get_feature_chunk(self) -> Tuple[torch.Tensor, int]:
+        """Get a chunk of feature frames."""
+        update_length = min(
+            self.num_frames - self.num_processed_frames, self.chunk_length
+        )
+        ret_length = update_length + self.pad_length
+
+        ret_feature = self.feature[:ret_length]
+        # Cut off used frames.
+        self.feature = self.feature[update_length:]
+
+        self.num_processed_frames += update_length
+        if self.num_processed_frames >= self.num_frames:
+            self._done = True
+
+        return ret_feature, ret_length
 
     @property
     def done(self) -> bool:
         """Return True if `self.input_finished()` has been invoked"""
         return self._done
 
-    def _fetch_frames(self) -> None:
-        """Fetch frames from the feature extractor"""
-        while self.num_fetched_frames < self.feature_extractor.num_frames_ready:
-            frame = self.feature_extractor.get_frame(self.num_fetched_frames)
-            self.feature_frames.append(frame)
-            self.num_fetched_frames += 1
-
     def decoding_result(self) -> List[int]:
         """Obtain current decoding result."""
         if self.decoding_method == "greedy_search":
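
Net effect of the refactor: Stream no longer pulls audio through an online feature extractor. Features are computed ahead of time and attached with set_feature(), which appends pad_length tail frames filled with log(1e-10) (a very small log-fbank value, standing in for silence); get_feature_chunk() then serves chunk_length frames at a time, each with pad_length frames of lookahead, until done flips to True. A self-contained sketch of that chunking logic under illustrative settings (the real values come from get_params(); e.g. with right_context_length = 8 and subsampling_factor = 4, pad_length = 8 + 2 * 4 + 3 = 19):

import math

import torch

# Illustrative values; in the real script they come from get_params().
chunk_length = 32
right_context_length = 8
subsampling_factor = 4
pad_length = right_context_length + 2 * subsampling_factor + 3  # = 19

# Stand-in for a precomputed fbank matrix of shape (num_frames, feat_dim).
feature = torch.randn(100, 80)
num_frames = feature.size(0)

# Tail padding, mirroring Stream.set_feature(), so even the final chunk
# has pad_length frames of lookahead available.
feature = torch.nn.functional.pad(
    feature, (0, 0, 0, pad_length), mode="constant", value=math.log(1e-10)
)

num_processed = 0
while num_processed < num_frames:
    # Mirrors Stream.get_feature_chunk(): consume up to chunk_length
    # frames, but return them together with the lookahead frames.
    update_length = min(num_frames - num_processed, chunk_length)
    chunk = feature[: update_length + pad_length]
    feature = feature[update_length:]
    num_processed += update_length
    # `chunk` would be fed to the streaming encoder here.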