Minor fixes on decode stream
This commit is contained in:
parent
5f64396112
commit
5fe60dec43
@ -52,13 +52,17 @@ class DecodeStream(object):
|
|||||||
|
|
||||||
# It contains a 2-D tensors representing the feature frames.
|
# It contains a 2-D tensors representing the feature frames.
|
||||||
self.features: torch.Tensor = None
|
self.features: torch.Tensor = None
|
||||||
|
|
||||||
|
self.num_frames: int = 0
|
||||||
# how many frames have been processed. (before subsampling).
|
# how many frames have been processed. (before subsampling).
|
||||||
# we only modify this value in `func:get_feature_frames`.
|
# we only modify this value in `func:get_feature_frames`.
|
||||||
self.num_frames = 0
|
|
||||||
self.num_processed_frames: int = 0
|
self.num_processed_frames: int = 0
|
||||||
|
|
||||||
self._done: bool = False
|
self._done: bool = False
|
||||||
|
|
||||||
# The transcript of current utterance.
|
# The transcript of current utterance.
|
||||||
self.ground_truth: str = ""
|
self.ground_truth: str = ""
|
||||||
|
|
||||||
# The decoding result (partial or final) of current utterance.
|
# The decoding result (partial or final) of current utterance.
|
||||||
self.hyp: List = []
|
self.hyp: List = []
|
||||||
|
|
||||||
@ -94,30 +98,28 @@ class DecodeStream(object):
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Set features tensor of current utterance."""
|
"""Set features tensor of current utterance."""
|
||||||
assert features.dim() == 2, features.dim()
|
assert features.dim() == 2, features.dim()
|
||||||
self.num_frames = features.size(0)
|
|
||||||
# tail padding
|
|
||||||
self.features = torch.nn.functional.pad(
|
self.features = torch.nn.functional.pad(
|
||||||
features,
|
features,
|
||||||
(0, 0, 0, self.pad_length),
|
(0, 0, 0, self.pad_length),
|
||||||
mode="constant",
|
mode="constant",
|
||||||
value=self.LOG_EPS,
|
value=self.LOG_EPS,
|
||||||
)
|
)
|
||||||
|
self.num_frames = self.features.size(0)
|
||||||
|
|
||||||
def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]:
|
def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]:
|
||||||
"""Consume chunk_size frames of features"""
|
"""Consume chunk_size frames of features"""
|
||||||
# plus 3 here because we subsampling features with
|
chunk_length = chunk_size + self.pad_length
|
||||||
# lengths = ((x_lens - 1) // 2 - 1) // 2
|
|
||||||
update_length = min(
|
ret_length = min(
|
||||||
self.num_frames - self.num_processed_frames, chunk_size
|
self.num_frames - self.num_processed_frames, chunk_length
|
||||||
)
|
)
|
||||||
ret_length = update_length + self.pad_length
|
|
||||||
|
|
||||||
ret_features = self.features[
|
ret_features = self.features[
|
||||||
self.num_processed_frames : self.num_processed_frames # noqa
|
self.num_processed_frames : self.num_processed_frames # noqa
|
||||||
+ ret_length
|
+ ret_length
|
||||||
]
|
]
|
||||||
|
|
||||||
self.num_processed_frames += update_length
|
self.num_processed_frames += chunk_size
|
||||||
if self.num_processed_frames >= self.num_frames:
|
if self.num_processed_frames >= self.num_frames:
|
||||||
self._done = True
|
self._done = True
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user