mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
Minor fixes on decode stream
This commit is contained in:
parent
5f64396112
commit
5fe60dec43
@ -52,13 +52,17 @@ class DecodeStream(object):
|
||||
|
||||
# It contains a 2-D tensors representing the feature frames.
|
||||
self.features: torch.Tensor = None
|
||||
|
||||
self.num_frames: int = 0
|
||||
# how many frames have been processed. (before subsampling).
|
||||
# we only modify this value in `func:get_feature_frames`.
|
||||
self.num_frames = 0
|
||||
self.num_processed_frames: int = 0
|
||||
|
||||
self._done: bool = False
|
||||
|
||||
# The transcript of current utterance.
|
||||
self.ground_truth: str = ""
|
||||
|
||||
# The decoding result (partial or final) of current utterance.
|
||||
self.hyp: List = []
|
||||
|
||||
@ -94,30 +98,28 @@ class DecodeStream(object):
|
||||
) -> None:
|
||||
"""Set features tensor of current utterance."""
|
||||
assert features.dim() == 2, features.dim()
|
||||
self.num_frames = features.size(0)
|
||||
# tail padding
|
||||
self.features = torch.nn.functional.pad(
|
||||
features,
|
||||
(0, 0, 0, self.pad_length),
|
||||
mode="constant",
|
||||
value=self.LOG_EPS,
|
||||
)
|
||||
self.num_frames = self.features.size(0)
|
||||
|
||||
def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]:
|
||||
"""Consume chunk_size frames of features"""
|
||||
# plus 3 here because we subsampling features with
|
||||
# lengths = ((x_lens - 1) // 2 - 1) // 2
|
||||
update_length = min(
|
||||
self.num_frames - self.num_processed_frames, chunk_size
|
||||
chunk_length = chunk_size + self.pad_length
|
||||
|
||||
ret_length = min(
|
||||
self.num_frames - self.num_processed_frames, chunk_length
|
||||
)
|
||||
ret_length = update_length + self.pad_length
|
||||
|
||||
ret_features = self.features[
|
||||
self.num_processed_frames : self.num_processed_frames # noqa
|
||||
+ ret_length
|
||||
]
|
||||
|
||||
self.num_processed_frames += update_length
|
||||
self.num_processed_frames += chunk_size
|
||||
if self.num_processed_frames >= self.num_frames:
|
||||
self._done = True
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user