Minor fixes on decode stream

This commit is contained in:
pkufool 2022-06-20 20:01:47 +08:00
parent 5f64396112
commit 5fe60dec43

View File

@ -52,13 +52,17 @@ class DecodeStream(object):
# It contains a 2-D tensor representing the feature frames.
self.features: torch.Tensor = None
self.num_frames: int = 0
# How many frames have been processed (before subsampling).
# We only modify this value in `func:get_feature_frames`.
self.num_frames = 0
self.num_processed_frames: int = 0
self._done: bool = False
# The transcript of current utterance.
self.ground_truth: str = ""
# The decoding result (partial or final) of current utterance.
self.hyp: List = []
@ -94,30 +98,28 @@ class DecodeStream(object):
) -> None:
"""Set features tensor of current utterance."""
assert features.dim() == 2, features.dim()
self.num_frames = features.size(0)
# tail padding
self.features = torch.nn.functional.pad(
features,
(0, 0, 0, self.pad_length),
mode="constant",
value=self.LOG_EPS,
)
self.num_frames = self.features.size(0)
def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]:
"""Consume chunk_size frames of features"""
# plus 3 here because we subsample features with
# lengths = ((x_lens - 1) // 2 - 1) // 2
update_length = min(
self.num_frames - self.num_processed_frames, chunk_size
chunk_length = chunk_size + self.pad_length
ret_length = min(
self.num_frames - self.num_processed_frames, chunk_length
)
ret_length = update_length + self.pad_length
ret_features = self.features[
self.num_processed_frames : self.num_processed_frames # noqa
+ ret_length
]
self.num_processed_frames += update_length
self.num_processed_frames += chunk_size
if self.num_processed_frames >= self.num_frames:
self._done = True