diff --git a/egs/librispeech/ASR/transducer_emformer/emformer.py b/egs/librispeech/ASR/transducer_emformer/emformer.py
index 631ff43fb..80849afd6 100644
--- a/egs/librispeech/ASR/transducer_emformer/emformer.py
+++ b/egs/librispeech/ASR/transducer_emformer/emformer.py
@@ -27,6 +27,66 @@ from torchaudio.models import Emformer as _Emformer
 LOG_EPSILON = math.log(1e-10)
 
 
+def unstack_states(
+    states: List[List[torch.Tensor]],
+) -> List[List[List[torch.Tensor]]]:
+    """Unstack the emformer state corresponding to a batch of utterances
+    into a list of states, where the i-th entry is the state from the i-th
+    utterance in the batch.
+
+    Args:
+      states:
+        A list-of-list of tensors. ``len(states)`` equals the number of
+        layers in the emformer. ``states[i]`` contains the states for
+        the i-th layer. ``states[i][k]`` is either a 3-D tensor of shape
+        ``(T, N, C)`` or a 2-D tensor of shape ``(C, N)``.
+    """
+    batch_size = states[0][0].size(1)
+    num_layers = len(states)
+
+    ans = [None] * batch_size
+    for i in range(batch_size):
+        ans[i] = [[] for _ in range(num_layers)]
+
+    for li, layer in enumerate(states):
+        for s in layer:
+            s_list = s.unbind(dim=1)
+            for bi, b in enumerate(ans):
+                b[li].append(s_list[bi].unsqueeze(dim=1))
+    return ans
+
+
+def stack_states(
+    state_list: List[List[List[torch.Tensor]]],
+) -> List[List[torch.Tensor]]:
+    """Stack a list of emformer states that correspond to separate utterances
+    into a single emformer state, so that it can be used as an input for
+    emformer when those utterances are formed into a batch.
+
+    Note:
+      It is the inverse of :func:`unstack_states`.
+
+    Args:
+      state_list:
+        Each element in state_list corresponds to the internal state
+        of the emformer model for a single utterance.
+    Returns:
+      Return a new state corresponding to a batch of utterances.
+      See the input argument of :func:`unstack_states` for the meaning
+      of the returned tensor.
+    """
+    ans = []
+    for layer in state_list[0]:
+        # layer is a list of tensors
+        ans.append([s for s in layer])
+
+    for states in state_list[1:]:
+        for li, layer in enumerate(states):
+            for si, s in enumerate(layer):
+                ans[li][si] = torch.cat([ans[li][si], s], dim=1)
+    return ans
+
+
 class Emformer(EncoderInterface):
     """This is just a simple wrapper around torchaudio.models.Emformer.
     We may replace it with our own implementation some time later.
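The state layout assumed by these two helpers can be exercised without building a full Emformer. Below is a minimal round-trip sketch; it assumes it is run from `egs/librispeech/ASR/transducer_emformer` so that `emformer.py` is importable, and the tensor sizes are made up purely for illustration.

```python
# Round-trip check for the state helpers added above (illustrative sizes only).
import torch

from emformer import stack_states, unstack_states

T, N, C = 4, 3, 8  # a batch of N = 3 utterances
states = [
    # One entry per layer; each layer holds tensors of shape (T, N, C)
    # or (C, N), with the batch dimension always at dim=1.
    [torch.rand(T, N, C), torch.rand(C, N)],
    [torch.rand(T, N, C), torch.rand(C, N)],
]

state_list = unstack_states(states)
assert len(state_list) == N  # one entry per utterance
assert len(state_list[0]) == len(states)  # one entry per layer

restacked = stack_states(state_list)
for layer, layer2 in zip(states, restacked):
    for s, s2 in zip(layer, layer2):
        assert torch.allclose(s, s2)
```

This mirrors what `test_emformer_streaming_forward` in `test_emformer.py` (further down in this patch) does with real `streaming_forward` states.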
diff --git a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py
index aca6e444d..93ca43ff3 100755
--- a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py
+++ b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py
@@ -18,16 +18,16 @@
 
 import argparse
 import logging
-import time
 from pathlib import Path
+from typing import List, Optional, Tuple
 
 import numpy as np
 import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from emformer import LOG_EPSILON
-from streaming_feature_extractor import Stream
+from emformer import LOG_EPSILON, stack_states, unstack_states
+from streaming_feature_extractor import FeatureExtractionStream
 from train import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import (
@@ -158,9 +158,166 @@
     return parser
 
 
+class StreamingAudioSamples(object):
+    """This class takes as input a list of audio samples and returns
+    them in a streaming fashion.
+    """
+
+    def __init__(self, samples: List[torch.Tensor]) -> None:
+        """
+        Args:
+          samples:
+            A list of audio samples. Each entry is a 1-D tensor of dtype
+            torch.float32, containing the audio samples of an utterance.
+        """
+        self.samples = samples
+        self.cur_indexes = [0] * len(self.samples)
+
+    @property
+    def done(self) -> bool:
+        """Return True if all samples have been processed.
+        Return False otherwise.
+        """
+        for i, samples in zip(self.cur_indexes, self.samples):
+            if i < samples.numel():
+                return False
+        return True
+
+    def get_next(self) -> List[torch.Tensor]:
+        """Return a list of audio samples. Each entry may have a different
+        length. It is OK if an entry contains no samples at all, which
+        means it has reached the end of the utterance.
+        """
+        ans = []
+
+        # Note: Either branch is fine. The purpose is to simulate streaming.
+        if False:
+            num = torch.randint(2000, 5000, (len(self.samples),)).tolist()
+        else:
+            num = [1024] * len(self.samples)
+
+        for i in range(len(self.samples)):
+            start = self.cur_indexes[i]
+            end = start + num[i]
+            self.cur_indexes[i] = end
+
+            s = self.samples[i][start:end]
+            ans.append(s)
+
+        return ans
+
+
+class StreamList(object):
+    def __init__(
+        self,
+        batch_size: int,
+        context_size: int,
+        blank_id: int,
+    ):
+        """
+        Args:
+          batch_size:
+            Size of this batch.
+          context_size:
+            Context size of the RNN-T decoder model.
+          blank_id:
+            The ID of the blank symbol of the BPE model.
+        """
+        self.streams = [
+            FeatureExtractionStream(
+                context_size=context_size, blank_id=blank_id
+            )
+            for _ in range(batch_size)
+        ]
+
+    @property
+    def done(self) -> bool:
+        """Return True if all streams have reached end of utterance.
+        That is, no more audio samples are available for any utterance.
+        """
+        return all(stream.done for stream in self.streams)
+
+    def accept_waveform(
+        self,
+        audio_samples: List[torch.Tensor],
+        sampling_rate: float,
+    ):
+        """Feed audio samples to each stream.
+        Args:
+          audio_samples:
+            A list of 1-D tensors containing the audio samples for each
+            utterance in the batch. If an entry is empty, it means
+            end-of-utterance has been reached.
+          sampling_rate:
+            Sampling rate of the given audio samples.
+        """
+        assert len(audio_samples) == len(self.streams)
+        for stream, samples in zip(self.streams, audio_samples):
+
+            if stream.done:
+                assert samples.numel() == 0
+                continue
+
+            stream.accept_waveform(
+                sampling_rate=sampling_rate,
+                waveform=samples,
+            )
+
+            if samples.numel() == 0:
+                stream.input_finished()
+
+    def build_batch(
+        self,
+        chunk_length: int,
+        segment_length: int,
+    ) -> Tuple[Optional[torch.Tensor], Optional[List[FeatureExtractionStream]]]:
+        """
+        Args:
+          chunk_length:
+            Number of frames for each chunk. It equals
+            ``segment_length + right_context_length``.
+          segment_length:
+            Number of frames for each segment.
+        Returns:
+          Return a tuple containing:
+            - features, a 3-D tensor of shape ``(num_active_streams, T, C)``
+            - active_streams, a list of active streams. We say a stream is
+              active when it has enough feature frames to be fed into the
+              encoder model.
+        """
+        feature_list = []
+        stream_list = []
+        for stream in self.streams:
+            if len(stream.feature_frames) >= chunk_length:
+                # chunk is a list of tensors, each of which
+                # has a shape (1, feature_dim)
+                chunk = stream.feature_frames[:chunk_length]
+                stream.feature_frames = stream.feature_frames[segment_length:]
+                features = torch.cat(chunk, dim=0).unsqueeze(0)
+                feature_list.append(features)
+                stream_list.append(stream)
+            elif stream.done and len(stream.feature_frames) > 0:
+                chunk = stream.feature_frames[:chunk_length]
+                stream.feature_frames = []
+                features = torch.cat(chunk, dim=0).unsqueeze(0)
+                features = torch.nn.functional.pad(
+                    features,
+                    (0, 0, 0, chunk_length - features.size(1)),
+                    value=LOG_EPSILON,
+                )
+                feature_list.append(features)
+                stream_list.append(stream)
+
+        if len(feature_list) == 0:
+            return None, None
+
+        features = torch.cat(feature_list, dim=0)
+        return features, stream_list
+
+
 def greedy_search(
     model: nn.Module,
-    stream: Stream,
+    streams: List[FeatureExtractionStream],
     encoder_out: torch.Tensor,
     sp: spm.SentencePieceProcessor,
 ):
@@ -171,7 +328,7 @@ def greedy_search(
       stream:
         A stream object.
       encoder_out:
-        A 2-D tensor of shape (T, encoder_out_dim) containing the output of
+        A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
         the encoder model.
       sp:
         The BPE model.
@@ -180,59 +337,130 @@ def greedy_search(
     context_size = model.decoder.context_size
     device = model.device
 
-    if stream.decoder_out is None:
+    if streams[0].decoder_out is None:
         decoder_input = torch.tensor(
-            [stream.hyp.ys[-context_size:]],
+            [stream.hyp.ys[-context_size:] for stream in streams],
             device=device,
             dtype=torch.int64,
         )
-        stream.decoder_out = model.decoder(
+        decoder_out = model.decoder(
             decoder_input,
             need_pad=False,
         ).unsqueeze(1)
-        # stream.decoder_out is of shape (1, 1, decoder_out_dim)
+        # decoder_out is of shape (N, 1, decoder_out_dim)
+    else:
+        decoder_out = torch.cat(
+            [stream.decoder_out for stream in streams],
+            dim=0,
+        )
 
-    assert encoder_out.ndim == 2
+    assert encoder_out.ndim == 3
 
-    T = encoder_out.size(0)
+    T = encoder_out.size(1)
     for t in range(T):
-        current_encoder_out = encoder_out[t].reshape(
-            1, 1, 1, encoder_out.size(-1)
-        )
-        logits = model.joiner(current_encoder_out, stream.decoder_out)
-        # logits is of shape (1, 1, 1, vocab_size)
-        y = logits.argmax().item()
-        if y == blank_id:
-            continue
-        stream.hyp.ys.append(y)
+        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
+        # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
 
-        decoder_input = torch.tensor(
-            [stream.hyp.ys[-context_size:]],
-            device=device,
-            dtype=torch.int64,
-        )
+        logits = model.joiner(current_encoder_out, decoder_out)
+        # logits' shape: (batch_size, 1, 1, vocab_size)
 
-        stream.decoder_out = model.decoder(
-            decoder_input,
-            need_pad=False,
-        ).unsqueeze(1)
+        logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
+        assert logits.ndim == 2, logits.shape
+        y = logits.argmax(dim=1).tolist()
+        emitted = False
+        for i, v in enumerate(y):
+            if v != blank_id:
+                streams[i].hyp.ys.append(v)
+                emitted = True
 
-        logging.info(
-            f"Partial result:\n{sp.decode(stream.hyp.ys[context_size:])}"
-        )
+        if emitted:
+            # Update the decoder output
+            decoder_input = torch.tensor(
+                [stream.hyp.ys[-context_size:] for stream in streams],
+                device=device,
+                dtype=torch.int64,
+            )
+            decoder_out = model.decoder(
+                decoder_input, need_pad=False
+            ).unsqueeze(1)
+
+    for k, s in enumerate(streams):
+        logging.info(
+            f"Partial result {k}:\n{sp.decode(s.hyp.ys[context_size:])}"
+        )
+
+    decoder_out_list = decoder_out.unbind(dim=0)
+
+    for i, d in enumerate(decoder_out_list):
+        streams[i].decoder_out = d.unsqueeze(0)
 
 
-def process_feature_frames(
+def process_features(
     model: nn.Module,
-    stream: Stream,
+    features: torch.Tensor,
+    streams: List[FeatureExtractionStream],
     sp: spm.SentencePieceProcessor,
-):
-    """Process the feature frames contained in ``stream.feature_frames``.
+) -> None:
+    """Process features for each stream in parallel.
+
     Args:
       model:
         The RNN-T model.
-      stream:
-        The stream corresponding to the input audio samples.
+      features:
+        A 3-D tensor of shape (N, T, C).
+      streams:
+        A list of streams of size (N,).
+      sp:
+        The BPE model.
+    """
+    assert features.ndim == 3
+    assert features.size(0) == len(streams)
+    batch_size = features.size(0)
+
+    device = model.device
+    features = features.to(device)
+    feature_lens = torch.full(
+        (batch_size,),
+        fill_value=features.size(1),
+        device=device,
+    )
+    if streams[0].states is None:
+        states = None
+    else:
+        state_list = [stream.states for stream in streams]
+        states = stack_states(state_list)
+
+    (encoder_out, encoder_out_lens, states,) = model.encoder.streaming_forward(
+        features,
+        feature_lens,
+        states,
+    )
+    state_list = unstack_states(states)
+    for i, s in enumerate(state_list):
+        streams[i].states = s
+
+    greedy_search(
+        model=model,
+        streams=streams,
+        encoder_out=encoder_out,
+        sp=sp,
+    )
+
+
+def decode_batch(
+    batched_samples: List[torch.Tensor],
+    model: nn.Module,
+    params: AttributeDict,
+    sp: spm.SentencePieceProcessor,
+) -> List[str]:
+    """
+    Args:
+      batched_samples:
+        A list of 1-D tensors containing the audio samples of each utterance.
+      model:
+        The RNN-T model.
+      params:
+        It is the return value of :func:`get_params`.
+      sp:
+        The BPE model.
     """
@@ -241,102 +469,41 @@ def process_feature_frames(
     right_context_length = model.encoder.right_context_length
 
+    # We add 3 here since the subsampling method uses
+    # ((len - 1) // 2 - 1) // 2
     chunk_length = (segment_length + 3) + right_context_length
 
-    device = model.device
-    while len(stream.feature_frames) >= chunk_length:
-        # a list of tensor, each with a shape (1, feature_dim)
-        this_chunk = stream.feature_frames[:chunk_length]
+    batch_size = len(batched_samples)
+    streaming_audio_samples = StreamingAudioSamples(batched_samples)
 
-        stream.feature_frames = stream.feature_frames[segment_length:]
-        features = torch.cat(this_chunk, dim=0).to(device)  # (T, feature_dim)
-        features = features.unsqueeze(0)  # (1, T, feature_dim)
-        feature_lens = torch.tensor([features.size(1)], device=device)
-        (
-            encoder_out,
-            encoder_out_lens,
-            stream.states,
-        ) = model.encoder.streaming_forward(
-            features,
-            feature_lens,
-            stream.states,
+    stream_list = StreamList(
+        batch_size=batch_size,
+        context_size=params.context_size,
+        blank_id=params.blank_id,
+    )
+
+    while not streaming_audio_samples.done:
+        samples = streaming_audio_samples.get_next()
+        stream_list.accept_waveform(
+            audio_samples=samples,
+            sampling_rate=params.sampling_rate,
         )
-        greedy_search(
-            model=model,
-            stream=stream,
-            encoder_out=encoder_out[0],
-            sp=sp,
+        features, active_streams = stream_list.build_batch(
+            chunk_length=chunk_length,
+            segment_length=segment_length,
         )
-
-    if stream.feature_extractor.is_last_frame(stream.num_fetched_frames - 1):
-        assert len(stream.feature_frames) < chunk_length
-
-        if len(stream.feature_frames) > 0:
-            this_chunk = stream.feature_frames[:chunk_length]
-            stream.feature_frames = []
-            features = torch.cat(this_chunk, dim=0)  # (T, feature_dim)
-            features = features.to(device).unsqueeze(0)  # (1, T, feature_dim)
-            features = torch.nn.functional.pad(
-                features,
-                (0, 0, 0, chunk_length - features.size(1)),
-                value=LOG_EPSILON,
-            )
-            feature_lens = torch.tensor([features.size(1)], device=device)
-            (
-                encoder_out,
-                encoder_out_lens,
-                stream.states,
-            ) = model.encoder.streaming_forward(
-                features,
-                feature_lens,
-                stream.states,
-            )
-            greedy_search(
+        if features is not None:
+            process_features(
                 model=model,
-                stream=stream,
-                encoder_out=encoder_out[0],
+                features=features,
+                streams=active_streams,
                 sp=sp,
             )
-
-
-def decode_one_utterance(
-    audio_samples: torch.Tensor,
-    model: nn.Module,
-    stream: Stream,
-    params: AttributeDict,
-    sp: spm.SentencePieceProcessor,
-):
-    """Decode one utterance.
-    Args:
-      audio_samples:
-        A 1-D float32 tensor of shape (num_samples,) containing the
-        audio samples.
-      model:
-        The RNN-T model.
-      feature_extractor:
-        The feature extractor.
-      params:
-        It is the return value of :func:`get_params`.
-      sp:
-        The BPE model.
-    """
-    i = 0
-    num_samples = audio_samples.size(0)
-    while i < num_samples:
-        # Simulate streaming.
-        this_chunk_num_samples = torch.randint(2000, 5000, (1,)).item()
-
-        thiks_chunk_samples = audio_samples[i : (i + this_chunk_num_samples)]
-        i += this_chunk_num_samples
-
-        stream.accept_waveform(
-            sampling_rate=params.sampling_rate,
-            waveform=thiks_chunk_samples,
-        )
-        process_feature_frames(model=model, stream=stream, sp=sp)
-
-    stream.input_finished()
-    process_feature_frames(model=model, stream=stream, sp=sp)
+    results = []
+    for s in stream_list.streams:
+        text = sp.decode(s.hyp.ys[params.context_size :])
+        results.append(text)
+    return results
 
 
 @torch.no_grad()
 def main():
@@ -403,31 +570,41 @@ def main():
 
     test_clean_cuts = librispeech.test_clean_cuts()
 
-    for num, cut in enumerate(test_clean_cuts):
-        logging.info(f"Processing {num}")
-        stream = Stream(
-            context_size=model.decoder.context_size,
-            blank_id=model.decoder.blank_id,
-        )
+    batch_size = 3
+    ground_truth = []
+    batched_samples = []
 
+    for num, cut in enumerate(test_clean_cuts):
         audio: np.ndarray = cut.load_audio()
         # audio.shape: (1, num_samples)
         assert len(audio.shape) == 2
        assert audio.shape[0] == 1, "Should be single channel"
         assert audio.dtype == np.float32, audio.dtype
-        assert audio.max() <= 1, "Should be normalized to [-1, 1])"
 
-        decode_one_utterance(
-            audio_samples=torch.from_numpy(audio).squeeze(0).to(device),
-            model=model,
-            stream=stream,
-            params=params,
-            sp=sp,
-        )
-        logging.info(f"The ground truth is:\n{cut.supervisions[0].text}")
-        if num >= 2:
+        # The trained model uses normalized samples
+        assert audio.max() <= 1, "Should be normalized to [-1, 1]"
+
+        samples = torch.from_numpy(audio).squeeze(0)
+
+        batched_samples.append(samples)
+        ground_truth.append(cut.supervisions[0].text)
+
+        if len(batched_samples) >= batch_size:
+            decoded_results = decode_batch(
+                batched_samples=batched_samples,
+                model=model,
+                params=params,
+                sp=sp,
+            )
+            s = "\n"
+            for i, (hyp, ref) in enumerate(zip(decoded_results, ground_truth)):
+                s += f"hyp {i}:\n{hyp}\n"
+                s += f"ref {i}:\n{ref}\n\n"
+            logging.info(s)
+            batched_samples = []
+            ground_truth = []
+            # Break after processing the first batch for test purposes
             break
-        time.sleep(2)  # So that you can see the decoded results
 
 
 if __name__ == "__main__":
diff --git a/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py b/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py
index 90f333694..e5b13df7f 100644
--- a/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py
+++ b/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py
@@ -40,7 +40,7 @@ def _create_streaming_feature_extractr() -> OnlineFeature:
     return OnlineFbank(opts)
 
 
-class Stream(object):
+class FeatureExtractionStream(object):
     def __init__(self, context_size: int, blank_id: int = 0) -> None:
         """Context size of the RNN-T decoder model."""
         self.feature_extractor = _create_streaming_feature_extractr()
@@ -62,6 +62,9 @@ class Stream(object):
         # corresponding to the decoder input self.hyp.ys[-context_size:]
         self.decoder_out: Optional[torch.Tensor] = None
 
+        # After calling `self.input_finished()`, we set this flag to True
+        self._done = False
+
     def accept_waveform(
         self,
         sampling_rate: float,
@@ -97,6 +100,12 @@ class Stream(object):
         """
         self.feature_extractor.input_finished()
         self._fetch_frames()
+        self._done = True
+
+    @property
+    def done(self) -> bool:
+        """Return True if `self.input_finished()` has been invoked."""
+        return self._done
 
     def _fetch_frames(self) -> None:
         """Fetch frames from the feature extractor"""
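For reference, this is roughly how the renamed `FeatureExtractionStream` and its new `done` property are driven. A minimal sketch, assuming 16 kHz input and an arbitrary 1024-sample chunk size, run from the recipe directory so the module is importable:

```python
# Minimal usage sketch for FeatureExtractionStream (chunk size is arbitrary).
import torch

from streaming_feature_extractor import FeatureExtractionStream

stream = FeatureExtractionStream(context_size=2, blank_id=0)
samples = torch.rand(16000)  # one second of fake audio at 16 kHz

for start in range(0, samples.numel(), 1024):
    stream.accept_waveform(
        sampling_rate=16000,
        waveform=samples[start : start + 1024],
    )

assert stream.done is False
stream.input_finished()  # flush the remaining frames and set the done flag
assert stream.done is True

# stream.feature_frames now holds a list of (1, feature_dim) tensors.
print(len(stream.feature_frames), stream.feature_frames[0].shape)
```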
diff --git a/egs/librispeech/ASR/transducer_emformer/test_emformer.py b/egs/librispeech/ASR/transducer_emformer/test_emformer.py
index d8c7b37e2..239ed24ac 100755
--- a/egs/librispeech/ASR/transducer_emformer/test_emformer.py
+++ b/egs/librispeech/ASR/transducer_emformer/test_emformer.py
@@ -25,7 +25,7 @@ To run this file, do:
 import warnings
 
 import torch
-from emformer import Emformer
+from emformer import Emformer, stack_states, unstack_states
 
 
 def test_emformer():
@@ -65,8 +65,41 @@ def test_emformer():
     print(f"Number of encoder parameters: {num_param}")
 
 
+def test_emformer_streaming_forward():
+    N = 3
+    C = 80
+
+    output_dim = 500
+
+    encoder = Emformer(
+        num_features=C,
+        output_dim=output_dim,
+        d_model=512,
+        nhead=8,
+        dim_feedforward=2048,
+        num_encoder_layers=20,
+        segment_length=16,
+        left_context_length=120,
+        right_context_length=4,
+        vgg_frontend=False,
+    )
+
+    x = torch.rand(N, 23, C)
+    x_lens = torch.full((N,), 23)
+    y, y_lens, states = encoder.streaming_forward(x=x, x_lens=x_lens)
+
+    state_list = unstack_states(states)
+    states2 = stack_states(state_list)
+
+    for ss, ss2 in zip(states, states2):
+        for s, s2 in zip(ss, ss2):
+            assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}"
+
+
+@torch.no_grad()
 def main():
-    test_emformer()
+    # test_emformer()
+    test_emformer_streaming_forward()
 
 
 if __name__ == "__main__":
diff --git a/egs/librispeech/ASR/transducer_emformer/test_streaming_feature_extractor.py b/egs/librispeech/ASR/transducer_emformer/test_streaming_feature_extractor.py
index 502668e83..4ce9c3284 100755
--- a/egs/librispeech/ASR/transducer_emformer/test_streaming_feature_extractor.py
+++ b/egs/librispeech/ASR/transducer_emformer/test_streaming_feature_extractor.py
@@ -24,11 +24,11 @@ To run this file, do:
 """
 
 import torch
-from streaming_feature_extractor import Stream
+from streaming_feature_extractor import FeatureExtractionStream
 
 
 def test_streaming_feature_extractor():
-    stream = Stream(context_size=2, blank_id=0)
+    stream = FeatureExtractionStream(context_size=2, blank_id=0)
     samples = torch.rand(16000)
     start = 0
     while True: