From 5fe60dec433e75564efac13824ca4359b6ffda6d Mon Sep 17 00:00:00 2001 From: pkufool Date: Mon, 20 Jun 2022 20:01:47 +0800 Subject: [PATCH] Minor fixes on decode stream --- .../decode_stream.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py index e96277ea7..ba5e80555 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py @@ -52,13 +52,17 @@ class DecodeStream(object): # It contains a 2-D tensors representing the feature frames. self.features: torch.Tensor = None + + self.num_frames: int = 0 # how many frames have been processed. (before subsampling). # we only modify this value in `func:get_feature_frames`. - self.num_frames = 0 self.num_processed_frames: int = 0 + self._done: bool = False + # The transcript of current utterance. self.ground_truth: str = "" + # The decoding result (partial or final) of current utterance. self.hyp: List = [] @@ -94,30 +98,28 @@ class DecodeStream(object): ) -> None: """Set features tensor of current utterance.""" assert features.dim() == 2, features.dim() - self.num_frames = features.size(0) - # tail padding self.features = torch.nn.functional.pad( features, (0, 0, 0, self.pad_length), mode="constant", value=self.LOG_EPS, ) + self.num_frames = self.features.size(0) def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]: """Consume chunk_size frames of features""" - # plus 3 here because we subsampling features with - # lengths = ((x_lens - 1) // 2 - 1) // 2 - update_length = min( - self.num_frames - self.num_processed_frames, chunk_size + chunk_length = chunk_size + self.pad_length + + ret_length = min( + self.num_frames - self.num_processed_frames, chunk_length ) - ret_length = update_length + self.pad_length ret_features = self.features[ self.num_processed_frames : self.num_processed_frames # noqa + ret_length ] - self.num_processed_frames += update_length + self.num_processed_frames += chunk_size if self.num_processed_frames >= self.num_frames: self._done = True