diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
index bba884120..956278546 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
@@ -37,6 +37,112 @@ from scaling import (
 
 from icefall.utils import make_pad_mask
 
+LOG_EPSILON = math.log(1e-10)
+
+
+def unstack_states(
+    states,
+) -> List[List[List[torch.Tensor]]]:
+    """Unstack the emformer state corresponding to a batch of utterances
+    into a list of states, where the i-th entry is the state of the i-th
+    utterance in the batch.
+
+    Args:
+      states:
+        A list ``[past_lens, attn_caches, conv_caches]`` holding the batched
+        state:
+        - ``past_lens`` is a 1-D tensor of shape ``(batch_size,)``;
+        - ``attn_caches[i][k]`` is a tensor whose dim 1 is the batch
+          dimension, caching the attention state of the i-th layer;
+        - ``conv_caches[i]`` is a tensor whose dim 0 is the batch
+          dimension, caching the convolution state of the i-th layer.
+
+    Returns:
+      A list of length ``batch_size``. The i-th entry is the state
+      ``[past_len, attn_caches, conv_caches]`` of the i-th utterance,
+      with the batch dimension removed from each cached tensor.
+    """
+
+    past_lens, attn_caches, conv_caches = states
+    batch_size = past_lens.size(0)
+    num_layers = len(attn_caches)
+
+    list_past_len = past_lens.tolist()
+
+    list_attn_caches = [None] * batch_size
+    for i in range(batch_size):
+        list_attn_caches[i] = [[] for _ in range(num_layers)]
+    for li, layer in enumerate(attn_caches):
+        for s in layer:
+            s_list = s.unbind(dim=1)
+            for bi, b in enumerate(list_attn_caches):
+                b[li].append(s_list[bi])
+
+    list_conv_caches = [None] * batch_size
+    for i in range(batch_size):
+        list_conv_caches[i] = [None] * num_layers
+    for li, layer in enumerate(conv_caches):
+        c_list = layer.unbind(dim=0)
+        for bi, b in enumerate(list_conv_caches):
+            b[li] = c_list[bi]
+
+    ans = [None] * batch_size
+    for i in range(batch_size):
+        ans[i] = [list_past_len[i], list_attn_caches[i], list_conv_caches[i]]
+
+    return ans
+
+
+def stack_states(
+    state_list,
+) -> List[List[torch.Tensor]]:
+    """Stack a list of emformer states that correspond to separate utterances
+    into a single emformer state, so that it can be used as an input for
+    emformer when those utterances are formed into a batch.
+
+    Note:
+      It is the inverse of :func:`unstack_states`.
+
+    Args:
+      state_list:
+        Each element in state_list corresponds to the internal state
+        of the emformer model for a single utterance.
+    Returns:
+      Return a new state corresponding to a batch of utterances.
+      See the input argument of :func:`unstack_states` for the meaning
+      of the returned value.
+ """ + batch_size = len(state_list) + + past_lens = [states[0] for states in state_list] + past_lens = torch.tensor([past_lens]) + + attn_caches = [] + for layer in state_list[0][1]: + if batch_size > 1: + # Note: We will stack attn_caches[layer][s][] later to get attn_caches[layer][s] # noqa + attn_caches.append([[s] for s in layer]) + else: + attn_caches.append([s.unsqueeze(1) for s in layer]) + for b, states in enumerate(state_list[1:], 1): + for li, layer in enumerate(states[1]): + for si, s in enumerate(layer): + attn_caches[li][si].append(s) + if b == batch_size - 1: + attn_caches[li][si] = torch.stack( + attn_caches[li][si], dim=1 + ) + + conv_caches = [] + for layer in state_list[0][2]: + if batch_size > 1: + # Note: We will stack conv_caches[layer][] later to get attn_caches[layer] # noqa + conv_caches.append([layer]) + else: + conv_caches.append(layer.unsqueeze(0)) + for b, states in enumerate(state_list[1:], 1): + for li, layer in enumerate(states[2]): + conv_caches[li].append(layer) + if b == batch_size - 1: + conv_caches[li] = torch.stack(conv_caches[li], dim=0) + + return [past_lens, attn_caches, conv_caches] + + class ConvolutionModule(nn.Module): """ConvolutionModule. diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py index 2064bd344..80c54d384 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py @@ -23,6 +23,7 @@ from pathlib import Path from typing import List, Optional, Tuple import k2 +from lhotse import CutSet import numpy as np import sentencepiece as spm import torch @@ -30,7 +31,7 @@ import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from beam_search import Hypothesis, HypothesisList, get_hyps_shape from emformer import LOG_EPSILON, stack_states, unstack_states -from streaming_feature_extractor import FeatureExtractionStream +from streaming_feature_extractor import Stream from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( @@ -157,174 +158,21 @@ def get_parser(): help="Sample rate of the audio", ) + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel", + ) + add_model_arguments(parser) return parser -class StreamingAudioSamples(object): - """This class takes as input a list of audio samples and returns - them in a streaming fashion. - """ - - def __init__(self, samples: List[torch.Tensor]) -> None: - """ - Args: - samples: - A list of audio samples. Each entry is a 1-D tensor of dtype - torch.float32, containing the audio samples of an utterance. - """ - self.samples = samples - self.cur_indexes = [0] * len(self.samples) - - @property - def done(self) -> bool: - """Return True if all samples have been processed. - Return False otherwise. - """ - for i, samples in zip(self.cur_indexes, self.samples): - if i < samples.numel(): - return False - return True - - def get_next(self) -> List[torch.Tensor]: - """Return a list of audio samples. Each entry may have different - lengths. It is OK if an entry contains no samples at all, which - means it reaches the end of the utterance. 
- """ - ans = [] - - num = [1024] * len(self.samples) - - for i in range(len(self.samples)): - start = self.cur_indexes[i] - end = start + num[i] - self.cur_indexes[i] = end - - s = self.samples[i][start:end] - ans.append(s) - - return ans - - -class StreamList(object): - def __init__( - self, - batch_size: int, - context_size: int, - decoding_method: str, - ): - """ - Args: - batch_size: - Size of this batch. - context_size: - Context size of the RNN-T decoder model. - decoding_method: - Decoding method. The possible values are: - - greedy_search - - modified_beam_search - """ - - self.streams = [ - FeatureExtractionStream( - context_size=context_size, decoding_method=decoding_method - ) - for _ in range(batch_size) - ] - - def __getitem__(self, i) -> FeatureExtractionStream: - return self.streams[i] - - @property - def done(self) -> bool: - """Return True if all streams have reached end of utterance. - That is, no more audio samples are available for all utterances. - """ - return all(stream.done for stream in self.streams) - - def accept_waveform( - self, - audio_samples: List[torch.Tensor], - sampling_rate: float, - ): - """Feed audio samples to each stream. - Args: - audio_samples: - A list of 1-D tensors containing the audio samples for each - utterance in the batch. If an entry is empty, it means - end-of-utterance has been reached. - sampling_rate: - Sampling rate of the given audio samples. - """ - assert len(audio_samples) == len(self.streams) - for stream, samples in zip(self.streams, audio_samples): - - if stream.done: - assert samples.numel() == 0 - continue - - stream.accept_waveform( - sampling_rate=sampling_rate, - waveform=samples, - ) - - if samples.numel() == 0: - stream.input_finished() - - def build_batch( - self, - chunk_length: int, - segment_length: int, - ) -> Tuple[Optional[torch.Tensor], Optional[List[FeatureExtractionStream]]]: - """ - Args: - chunk_length: - Number of frames for each chunk. It equals to - ``segment_length + right_context_length``. - segment_length - Number of frames for each segment. - Returns: - Return a tuple containing: - - features, a 3-D tensor of shape ``(num_active_streams, T, C)`` - - active_streams, a list of active streams. We say a stream is - active when it has enough feature frames to be fed into the - encoder model. 
- """ - feature_list = [] - stream_list = [] - for stream in self.streams: - if len(stream.feature_frames) >= chunk_length: - # this_chunk is a list of tensors, each of which - # has a shape (1, feature_dim) - chunk = stream.feature_frames[:chunk_length] - stream.feature_frames = stream.feature_frames[segment_length:] - features = torch.cat(chunk, dim=0) - feature_list.append(features) - stream_list.append(stream) - elif stream.done and len(stream.feature_frames) > 0: - chunk = stream.feature_frames[:chunk_length] - stream.feature_frames = [] - features = torch.cat(chunk, dim=0) - features = torch.nn.functional.pad( - features, - (0, 0, 0, chunk_length - features.size(0)), - mode="constant", - value=LOG_EPSILON, - ) - feature_list.append(features) - stream_list.append(stream) - - if len(feature_list) == 0: - return None, None - - features = torch.stack(feature_list, dim=0) - return features, stream_list - - def greedy_search( model: nn.Module, - streams: List[FeatureExtractionStream], + streams: List[Stream], encoder_out: torch.Tensor, sp: spm.SentencePieceProcessor, ): @@ -401,7 +249,7 @@ def greedy_search( def modified_beam_search( model: nn.Module, - streams: List[FeatureExtractionStream], + streams: List[Stream], encoder_out: torch.Tensor, sp: spm.SentencePieceProcessor, beam: int = 4, @@ -513,10 +361,69 @@ def modified_beam_search( logging.info(f"Partial result {i}:\n{result}") +def build_batch( + decode_steams: List[Stream], + chunk_length: int, + segment_length: int, +) -> Tuple[ + Optional[torch.Tensor], + Optional[torch.tensor], + Optional[List[Stream]], +]: + """ + Args: + chunk_length: + Number of frames for each chunk. It equals to + ``segment_length + right_context_length``. + segment_length + Number of frames for each segment. + Returns: + Return a tuple containing: + - features, a 3-D tensor of shape ``(num_active_streams, T, C)`` + - active_streams, a list of active streams. We say a stream is + active when it has enough feature frames to be fed into the + encoder model. 
+ """ + feature_list = [] + length_list = [] + stream_list = [] + for stream in decode_steams: + if len(stream.feature_frames) >= chunk_length: + # this_chunk is a list of tensors, each of which + # has a shape (1, feature_dim) + chunk = stream.feature_frames[:chunk_length] + stream.feature_frames = stream.feature_frames[segment_length:] + features = torch.cat(chunk, dim=0) + feature_list.append(features) + length_list.append(chunk_length) + stream_list.append(stream) + elif stream.done and len(stream.feature_frames) > 0: + chunk = stream.feature_frames[:chunk_length] + stream.feature_frames = [] + features = torch.cat(chunk, dim=0) + length_list.append(features.size(0)) + features = torch.nn.functional.pad( + features, + (0, 0, 0, chunk_length - features.size(0)), + mode="constant", + value=LOG_EPSILON, + ) + feature_list.append(features) + stream_list.append(stream) + + if len(feature_list) == 0: + return None, None, None + + features = torch.stack(feature_list, dim=0) + lengths = torch.cat(length_list) + return features, lengths, stream_list + + def process_features( model: nn.Module, features: torch.Tensor, - streams: List[FeatureExtractionStream], + feature_lens: torch.Tensor, + streams: List[Stream], params: AttributeDict, sp: spm.SentencePieceProcessor, ) -> None: @@ -536,30 +443,20 @@ def process_features( """ assert features.ndim == 3 assert features.size(0) == len(streams) - batch_size = features.size(0) + assert feature_lens.size(0) == len(streams) device = model.device features = features.to(device) - feature_lens = torch.full( - (batch_size,), - fill_value=features.size(1), - device=device, - ) - # Caution: It has a limitation as it assumes that - # if one of the stream has an empty state, then all other - # streams also have empty states. - if streams[0].states is None: - states = None - else: - state_list = [stream.states for stream in streams] - states = stack_states(state_list) + state_list = [stream.states for stream in streams] + states = stack_states(state_list) - (encoder_out, encoder_out_lens, states,) = model.encoder.streaming_forward( + encoder_out, encoder_out_lens, states = model.encoder.infer( features, feature_lens, states, ) + state_list = unstack_states(states) for i, s in enumerate(state_list): streams[i].states = s @@ -585,64 +482,80 @@ def process_features( ) -def decode_batch( - batched_samples: List[torch.Tensor], - model: nn.Module, +def decode_dataset( params: AttributeDict, + cuts: CutSet, + model: nn.Module, sp: spm.SentencePieceProcessor, -) -> List[str]: - """ +): + """Decode dataset. Args: - batched_samples: - A list of 1-D tensors containing the audio samples of each utterance. - model: - The RNN-T model. - params: - It is the return value of :func:`get_params`. - sp: - The BPE model. 
""" + device = next(model.parameters()).device + # number of frames before subsampling segment_length = model.encoder.segment_length - right_context_length = model.encoder.right_context_length + # 5 = 3 + 2 + # 1) add 3 here since the subsampling method is using + # ((len - 1) // 2 - 1) // 2) + # 2) add 2 here we will drop first and last frame after subsampling + chunk_length = (segment_length + 5) + right_context_length - # We add 3 here since the subsampling method is using - # ((len - 1) // 2 - 1) // 2) - chunk_length = (segment_length + 3) + right_context_length + decode_results = [] + streams = [] + for num, cut in enumerate(cuts): + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype - batch_size = len(batched_samples) - streaming_audio_samples = StreamingAudioSamples(batched_samples) + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" - stream_list = StreamList( - batch_size=batch_size, - context_size=params.context_size, - decoding_method=params.decoding_method, - ) + samples = torch.from_numpy(audio).squeeze(0) - while not streaming_audio_samples.done: - samples = streaming_audio_samples.get_next() - stream_list.accept_waveform( - audio_samples=samples, - sampling_rate=params.sampling_rate, + # Each uttetance has a Stream + stream = Stream( + params=params, + audio_sample=samples, + ground_truth=cut.supervisions[0].text, + device=device, ) - features, active_streams = stream_list.build_batch( - chunk_length=chunk_length, - segment_length=segment_length, - ) - if features is not None: - process_features( - model=model, - features=features, - streams=active_streams, - params=params, - sp=sp, + streams.append(stream) + + while len(streams) >= params.num_decode_streams: + for stream in streams: + stream.accept_waveform() + + # try to build batch + features, active_streams = build_batch( + chunk_length=chunk_length, + segment_length=segment_length, ) - results = [] - for stream in stream_list.streams: - text = sp.decode(stream.decoding_result()) - results.append(text) - return results + if features is not None: + process_features( + model=model, + features=features, + streams=active_streams, + params=params, + sp=sp, + ) + + new_streams = [] + for stream in streams: + if stream.done: + decode_results.append( + ( + stream.ground_truth.split(), + sp.decode(stream.decoding_result()).split(), + ) + ) + else: + new_streams.append(stream) + del streams + streams = new_streams @torch.no_grad() @@ -726,8 +639,8 @@ def main(): samples = torch.from_numpy(audio).squeeze(0) - batched_samples.append(samples) - ground_truth.append(cut.supervisions[0].text) + # batched_samples.append(samples) + # ground_truth.append(cut.supervisions[0].text) if len(batched_samples) >= batch_size: decoded_results = decode_batch( diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_feature_extractor.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_feature_extractor.py index 1392d9ae0..b89c6acdd 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_feature_extractor.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_feature_extractor.py @@ -42,12 +42,12 @@ def _create_streaming_feature_extractor() -> OnlineFeature: return OnlineFbank(opts) -class FeatureExtractionStream(object): +class Stream(object): def __init__( 
         self,
         params: AttributeDict,
-        context_size: int,
-        decoding_method: str,
+        audio_sample: torch.Tensor,
+        ground_truth: str,
         device: torch.device = torch.device("cpu"),
     ) -> None:
         """
@@ -63,6 +63,7 @@ class FeatureExtractionStream(object):
         # It contains a list of 1-D tensors representing the feature frames.
         self.feature_frames: List[torch.Tensor] = []
         self.num_fetched_frames = 0
+        # After calling `self.input_finished()`, we set this flag to True
         self._done = False
 
@@ -87,20 +88,29 @@ class FeatureExtractionStream(object):
         self.states = [past_len, attn_caches, conv_caches]
 
         # It use different attributes for different decoding methods.
-        self.context_size = context_size
-        self.decoding_method = decoding_method
-        if decoding_method == "greedy_search":
+        self.context_size = params.context_size
+        self.decoding_method = params.decoding_method
+        if params.decoding_method == "greedy_search":
             self.hyp: Optional[List[int]] = None
             self.decoder_out: Optional[torch.Tensor] = None
-        elif decoding_method == "modified_beam_search":
+        elif params.decoding_method == "modified_beam_search":
            self.hyps = HypothesisList()
         else:
-            raise ValueError(f"Unsupported decoding method: {decoding_method}")
+            raise ValueError(
+                f"Unsupported decoding method: {params.decoding_method}"
+            )
+
+        self.sampling_rate = params.sampling_rate
+        self.audio_sample = audio_sample
+        # Current index into audio_sample
+        self.cur_index = 0
+
+        self.ground_truth = ground_truth
 
     def accept_waveform(
         self,
-        sampling_rate: float,
-        waveform: torch.Tensor,
+        # sampling_rate: float,
+        # waveform: torch.Tensor,
     ) -> None:
         """Feed audio samples to the feature extractor and compute features
         if there are enough samples available.
@@ -120,12 +130,20 @@ class FeatureExtractionStream(object):
             A 1-D torch tensor of dtype torch.float32 containing audio
             samples. It should be on CPU.
         """
+        start = self.cur_index
+        end = self.cur_index + 1024
+        waveform = self.audio_sample[start:end]
+        self.cur_index = end
+
         self.feature_extractor.accept_waveform(
-            sampling_rate=sampling_rate,
+            sampling_rate=self.sampling_rate,
             waveform=waveform,
         )
         self._fetch_frames()
 
+        if waveform.numel() == 0:
+            self.input_finished()
+
     def input_finished(self) -> None:
         """Signal that no more audio samples available and the feature
         extractor should flush the buffered samples to compute frames.
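
For reference, here is a minimal round-trip sketch of the new `stack_states` / `unstack_states` helpers; it is not part of the patch, and the cache sizes below are illustrative placeholders, not the real emformer dimensions. The only assumptions are the ones visible in the code above: attention caches are batched along dim 1, convolution caches along dim 0, and `past_lens` is a 1-D tensor with one entry per utterance.

```python
import torch

# Assumes this script sits next to the patched emformer.py.
from emformer import stack_states, unstack_states

num_layers = 2
d_model = 4


def make_utterance_state(past_len: int):
    # Per layer: a list of attention cache tensors (no batch dim yet) ...
    attn_caches = [
        [torch.zeros(8, d_model) for _ in range(3)] for _ in range(num_layers)
    ]
    # ... and one convolution cache tensor (no batch dim yet).
    conv_caches = [torch.zeros(d_model, 3) for _ in range(num_layers)]
    return [past_len, attn_caches, conv_caches]


state_list = [make_utterance_state(0), make_utterance_state(32)]

# Batch the two per-utterance states ...
past_lens, attn_caches, conv_caches = stack_states(state_list)
assert past_lens.shape == (2,)
assert attn_caches[0][0].shape == (8, 2, d_model)  # batch dim is dim 1
assert conv_caches[0].shape == (2, d_model, 3)     # batch dim is dim 0

# ... and recover them again.
unbatched = unstack_states([past_lens, attn_caches, conv_caches])
assert len(unbatched) == 2
assert unbatched[1][0] == 32  # past_len survives the round trip
```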
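
Similarly, a small standalone illustration of the padding that `build_batch` applies to the last, shorter chunk of a finished stream: the chunk is padded along the time axis up to `chunk_length` with `LOG_EPSILON` so it can still be batched with full chunks, while the returned lengths record how many frames are actually valid. The concrete numbers below are made up for the example.

```python
import math

import torch

LOG_EPSILON = math.log(1e-10)

chunk_length = 36  # made-up value: segment + right context, before subsampling
feature_dim = 80

# Pretend a finished stream has only 20 feature frames left.
features = torch.randn(20, feature_dim)
valid_len = features.size(0)

padded = torch.nn.functional.pad(
    features,
    (0, 0, 0, chunk_length - features.size(0)),  # pad only the end of the time axis
    mode="constant",
    value=LOG_EPSILON,
)

assert padded.shape == (chunk_length, feature_dim)
# The padded tail is filled with LOG_EPSILON; `valid_len` is what build_batch
# would report in its lengths tensor for this stream.
assert torch.all(padded[valid_len:] == LOG_EPSILON)
```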