diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_feature_extractor.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_feature_extractor.py
index b89c6acdd..b7293cac6 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_feature_extractor.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_feature_extractor.py
@@ -14,7 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional
+import math
+from typing import List, Optional, Tuple
 
 import torch
 from beam_search import HypothesisList
@@ -46,53 +47,27 @@ class Stream(object):
     def __init__(
         self,
         params: AttributeDict,
-        audio_sample: torch.Tensor,
-        ground_truth: str,
-        device: torch.device = torch.devive("cpu"),
+        device: torch.device = torch.device("cpu"),
     ) -> None:
         """
         Args:
-          context_size:
-            Context size of the RNN-T decoder model.
-          decoding_method:
-            Decoding method. The possible values are:
-              - greedy_search
-              - modified_beam_search
+          params:
+            It's the return value of :func:`get_params`.
+          device:
+            The device to run this stream.
         """
-        self.feature_extractor = _create_streaming_feature_extractor()
-        # It contains a list of 1-D tensors representing the feature frames.
-        self.feature_frames: List[torch.Tensor] = []
-        self.num_fetched_frames = 0
-
-        # After calling `self.input_finished()`, we set this flag to True
-        self._done = False
+        self.device = device
+        # Containing attention caches and convolution caches
+        self.states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]] = None
 
         # Initailize zero states.
-        past_len: int = 0
-        attn_caches = [
-            [
-                torch.zeros(params.memory_size, params.d_model, device=device),
-                torch.zeros(
-                    params.left_context_length, params.d_model, device=device
-                ),
-                torch.zeros(
-                    params.left_context_length, params.d_model, device=device
-                ),
-            ]
-            for _ in range(params.num_encoder_layers)
-        ]
-        conv_caches = [
-            torch.zeros(params.d_model, params.cnn_module_kernel, device=device)
-            for _ in range(params.num_encoder_layers)
-        ]
-        self.states = [past_len, attn_caches, conv_caches]
+        self.init_states(params)
 
         # It use different attributes for different decoding methods.
         self.context_size = params.context_size
         self.decoding_method = params.decoding_method
         if params.decoding_method == "greedy_search":
-            self.hyp: Optional[List[int]] = None
-            self.decoder_out: Optional[torch.Tensor] = None
+            self.hyp = [params.blank_id] * params.context_size
         elif params.decoding_method == "modified_beam_search":
             self.hyps = HypothesisList()
         else:
@@ -100,70 +75,86 @@ class Stream(object):
                 f"Unsupported decoding method: {params.decoding_method}"
             )
 
-        self.sample_rate = params.sample_rate
-        self.audio_sample = audio_sample
-        # Current index of sample
-        self.cur_index = 0
+        self.ground_truth: str = ""
+        self.feature: torch.Tensor = None
+        # Make sure all feature frames can be used.
+        # Add 2 here since we will drop the first and last after subsampling.
+        self.chunk_length = params.chunk_length
+        self.pad_length = (
+            params.right_context_length + 2 * params.subsampling_factor + 3
+        )
+        self.num_frames = 0
+        self.num_processed_frames = 0
+
+        # After all feature frames are processed, we set this flag to True
+        self._done = False
+
+    def set_feature(self, feature: torch.Tensor) -> None:
+        assert feature.dim() == 2, feature.dim()
+        self.num_frames = feature.size(0)
+        # tail padding
+        self.feature = torch.nn.functional.pad(
+            feature,
+            (0, 0, 0, self.pad_length),
+            mode="constant",
+            value=math.log(1e-10),
+        )
+
+    def set_ground_truth(self, ground_truth: str) -> None:
         self.ground_truth = ground_truth
 
-    def accept_waveform(
-        self,
-        # sampling_rate: float,
-        # waveform: torch.Tensor,
-    ) -> None:
-        """Feed audio samples to the feature extractor and compute features
-        if there are enough samples available.
+    def init_states(self, params: AttributeDict) -> None:
+        attn_caches = [
+            [
+                torch.zeros(
+                    params.memory_size, params.encoder_dim, device=self.device
+                ),
+                torch.zeros(
+                    params.left_context_length // params.subsampling_factor,
+                    params.encoder_dim,
+                    device=self.device,
+                ),
+                torch.zeros(
+                    params.left_context_length // params.subsampling_factor,
+                    params.encoder_dim,
+                    device=self.device,
+                ),
+            ]
+            for _ in range(params.num_encoder_layers)
+        ]
+        conv_caches = [
+            torch.zeros(
+                params.encoder_dim,
+                params.cnn_module_kernel - 1,
+                device=self.device,
+            )
+            for _ in range(params.num_encoder_layers)
+        ]
+        self.states = [attn_caches, conv_caches]
 
-        Caution:
-          The range of the audio samples should match the one used in the
-          training. That is, if you use the range [-1, 1] in the training, then
-          the input audio samples should also be normalized to [-1, 1].
-
-        Args
-          sampling_rate:
-            The sampling rate of the input audio samples. It is used for sanity
-            check to ensure that the input sampling rate equals to the one
-            used in the extractor. If they are not equal, then no resampling
-            will be performed; instead an error will be thrown.
-          waveform:
-            A 1-D torch tensor of dtype torch.float32 containing audio samples.
-            It should be on CPU.
-        """
-        start = self.cur_index
-        end = self.cur_index + 1024
-        waveform = self.audio_sample[start:end]
-        self.cur_index = end
-
-        self.feature_extractor.accept_waveform(
-            sampling_rate=self.sampling_rate,
-            waveform=waveform,
+    def get_feature_chunk(self) -> Tuple[torch.Tensor, int]:
+        """Get a chunk of feature frames."""
+        update_length = min(
+            self.num_frames - self.num_processed_frames, self.chunk_length
         )
-        self._fetch_frames()
+        ret_length = update_length + self.pad_length
 
-        if waveform.numel() == 0:
-            self.input_finished()
+        ret_feature = self.feature[:ret_length]
+        # Cut off used frames.
+        self.feature = self.feature[update_length:]
 
-    def input_finished(self) -> None:
-        """Signal that no more audio samples available and the feature
-        extractor should flush the buffered samples to compute frames.
- """ - self.feature_extractor.input_finished() - self._fetch_frames() - self._done = True + self.num_processed_frames += update_length + if self.num_processed_frames >= self.num_frames: + self._done = True + + return ret_feature, ret_length @property def done(self) -> bool: """Return True if `self.input_finished()` has been invoked""" return self._done - def _fetch_frames(self) -> None: - """Fetch frames from the feature extractor""" - while self.num_fetched_frames < self.feature_extractor.num_frames_ready: - frame = self.feature_extractor.get_frame(self.num_fetched_frames) - self.feature_frames.append(frame) - self.num_fetched_frames += 1 - def decoding_result(self) -> List[int]: """Obtain current decoding result.""" if self.decoding_method == "greedy_search":