diff --git a/.flake8 b/.flake8
index dd9239b2d..019dcc9da 100644
--- a/.flake8
+++ b/.flake8
@@ -15,3 +15,7 @@ exclude =
   **/data/**,
   icefall/shared/make_kn_lm.py,
   icefall/__init__.py
+
+ignore =
+  # E203 whitespace before ':'
+  E203,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless/joiner.py
index 7c5a93a86..fbb30e057 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless/joiner.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/joiner.py
@@ -32,13 +32,16 @@ class Joiner(nn.Module):
         """
         Args:
           encoder_out:
-            Output from the encoder. Its shape is (N, T, s_range, C).
+            Output from the encoder. Its shape is (N, T, s_range, C) for
+            training and (N, C) for streaming decoding.
           decoder_out:
-            Output from the decoder. Its shape is (N, T, s_range, C).
+            Output from the decoder. Its shape is (N, T, s_range, C) for
+            training and (N, C) for streaming decoding.
         Returns:
           Return a tensor of shape (N, T, s_range, C).
         """
-        assert encoder_out.ndim == decoder_out.ndim == 4
+        assert encoder_out.ndim == decoder_out.ndim
+        assert encoder_out.ndim in (2, 4)
         assert encoder_out.shape == decoder_out.shape

         logit = encoder_out + decoder_out
diff --git a/egs/librispeech/ASR/transducer_emformer/emformer.py b/egs/librispeech/ASR/transducer_emformer/emformer.py
index b3693d660..0029b42af 100644
--- a/egs/librispeech/ASR/transducer_emformer/emformer.py
+++ b/egs/librispeech/ASR/transducer_emformer/emformer.py
@@ -27,6 +27,75 @@ from torchaudio.models import Emformer as _Emformer
 LOG_EPSILON = math.log(1e-10)


+def unstack_states(
+    states: List[List[torch.Tensor]],
+) -> List[List[List[torch.Tensor]]]:
+    """Unstack the emformer state corresponding to a batch of utterances
+    into a list of states, where the i-th entry is the state from the i-th
+    utterance in the batch.
+
+    Args:
+      states:
+        A list-of-list of tensors. ``len(states)`` equals the number of
+        layers in the emformer. ``states[i]`` contains the states for
+        the i-th layer. ``states[i][k]`` is either a 3-D tensor of shape
+        ``(T, N, C)`` or a 2-D tensor of shape ``(C, N)``.
+    """
+    batch_size = states[0][0].size(1)
+    num_layers = len(states)
+
+    ans = [None] * batch_size
+    for i in range(batch_size):
+        ans[i] = [[] for _ in range(num_layers)]
+
+    for li, layer in enumerate(states):
+        for s in layer:
+            s_list = s.unbind(dim=1)
+            # We will use stack(dim=1) later in stack_states()
+            for bi, b in enumerate(ans):
+                b[li].append(s_list[bi])
+    return ans
+
+
+def stack_states(
+    state_list: List[List[List[torch.Tensor]]],
+) -> List[List[torch.Tensor]]:
+    """Stack a list of emformer states that correspond to separate utterances
+    into a single emformer state so that it can be used as an input for
+    emformer when those utterances are formed into a batch.
+
+    Note:
+      It is the inverse of :func:`unstack_states`.
+
+    Args:
+      state_list:
+        Each element in state_list corresponds to the internal state
+        of the emformer model for a single utterance.
+    Returns:
+      Return a new state corresponding to a batch of utterances.
+      See the input argument of :func:`unstack_states` for the meaning
+      of the returned tensor.
+ """ + batch_size = len(state_list) + ans = [] + for layer in state_list[0]: + # layer is a list of tensors + if batch_size > 1: + ans.append([[s] for s in layer]) + # Note: We will stack ans[layer][s][] later to get ans[layer][s] + else: + ans.append([s.unsqueeze(1) for s in layer]) + + for b, states in enumerate(state_list[1:], 1): + for li, layer in enumerate(states): + for si, s in enumerate(layer): + ans[li][si].append(s) + if b == batch_size - 1: + ans[li][si] = torch.stack(ans[li][si], dim=1) + # We will use unbind(dim=1) later in unstack_states() + return ans + + class Emformer(EncoderInterface): """This is just a simple wrapper around torchaudio.models.Emformer. We may replace it with our own implementation some time later. @@ -63,11 +132,11 @@ class Emformer(EncoderInterface): num_encoder_layers: Number of encoder layers. segment_length: - Number of frames per segment. + Number of frames per segment before subsampling. left_context_length: - Number of frames in the left context. + Number of frames in the left context before subsampling. right_context_length: - Number of frames in the right context. + Number of frames in the right context before subsampling. max_memory_size: TODO. dropout: @@ -94,6 +163,7 @@ class Emformer(EncoderInterface): else: self.encoder_embed = Conv2dSubsampling(num_features, d_model) + self.segment_length = segment_length self.right_context_length = right_context_length assert right_context_length % subsampling_factor == 0 diff --git a/egs/librispeech/ASR/transducer_emformer/export.py b/egs/librispeech/ASR/transducer_emformer/export.py new file mode 100755 index 000000000..9c665ea1a --- /dev/null +++ b/egs/librispeech/ASR/transducer_emformer/export.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. 
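For reference, the stack_states()/unstack_states() helpers added to emformer.py above are meant to be exact inverses of each other; the new test in test_emformer.py further below checks this on real model states. The following is a minimal, self-contained sketch of that round trip, not part of the patch. It assumes it is run from egs/librispeech/ASR/transducer_emformer/ (or with that directory on PYTHONPATH) so that the emformer import resolves, and it uses made-up per-layer state shapes with T=4, C=8.

import torch

from emformer import stack_states, unstack_states

num_layers = 2
batch_size = 3

# Per the docstrings above, each per-layer state is either a 3-D tensor
# of shape (T, N, C) or a 2-D tensor of shape (C, N).  Here T=4, C=8.
states = [
    [torch.rand(4, batch_size, 8), torch.rand(8, batch_size)]
    for _ in range(num_layers)
]

state_list = unstack_states(states)
assert len(state_list) == batch_size      # one entry per utterance
assert len(state_list[0]) == num_layers   # each entry holds per-layer states

restacked = stack_states(state_list)
for layer, layer2 in zip(states, restacked):
    for s, s2 in zip(layer, layer2):
        assert torch.allclose(s, s2)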
+""" +Usage: +./transducer_emformer/export.py \ + --exp-dir ./transducer_emformer/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +It will generate a file exp_dir/pretrained.pt + +To use the generated file with `transducer_emformer/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./transducer_emformer/decode.py \ + --exp-dir ./transducer_emformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 1000 \ + --bpe-model data/lang_bpe_500/bpe.model +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + assert args.jit is False, "Support torchscript will be added later" + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.eval() + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py index 0b21b3f59..bb71310b7 100755 --- a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py +++ b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py @@ -18,16 +18,16 @@ import argparse import logging -import time from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Tuple -import kaldifeat import numpy as np import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule +from emformer import LOG_EPSILON, stack_states, unstack_states +from streaming_feature_extractor import FeatureExtractionStream from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( @@ -147,10 +147,10 @@ def get_parser(): ) parser.add_argument( - "--sample-rate", - type=int, + "--sampling-rate", + type=float, default=16000, - help="The sample rate of the input sound file", + help="Sample rate of the audio", ) add_model_arguments(parser) @@ -158,115 +158,352 @@ def get_parser(): return parser -def get_feature_extractor( - params: AttributeDict, -) -> kaldifeat.Fbank: - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = params.device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = True - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim +class StreamingAudioSamples(object): + """This class takes as input a list of audio samples and returns + them in a streaming fashion. 
+ """ - return kaldifeat.Fbank(opts) + def __init__(self, samples: List[torch.Tensor]) -> None: + """ + Args: + samples: + A list of audio samples. Each entry is a 1-D tensor of dtype + torch.float32, containing the audio samples of an utterance. + """ + self.samples = samples + self.cur_indexes = [0] * len(self.samples) + + @property + def done(self) -> bool: + """Return True if all samples have been processed. + Return False otherwise. + """ + for i, samples in zip(self.cur_indexes, self.samples): + if i < samples.numel(): + return False + return True + + def get_next(self) -> List[torch.Tensor]: + """Return a list of audio samples. Each entry may have different + lengths. It is OK if an entry contains no samples at all, which + means it reaches the end of the utterance. + """ + ans = [] + + num = [1024] * len(self.samples) + + for i in range(len(self.samples)): + start = self.cur_indexes[i] + end = start + num[i] + self.cur_indexes[i] = end + + s = self.samples[i][start:end] + ans.append(s) + + return ans -def decode_one_utterance( - audio_samples: torch.Tensor, +class StreamList(object): + def __init__( + self, + batch_size: int, + context_size: int, + blank_id: int, + ): + """ + Args: + batch_size: + Size of this batch. + context_size: + Context size of the RNN-T decoder model. + blank_id: + The ID of the blank symbol of the BPE model. + """ + self.streams = [ + FeatureExtractionStream( + context_size=context_size, blank_id=blank_id + ) + for _ in range(batch_size) + ] + + @property + def done(self) -> bool: + """Return True if all streams have reached end of utterance. + That is, no more audio samples are available for all utterances. + """ + return all(stream.done for stream in self.streams) + + def accept_waveform( + self, + audio_samples: List[torch.Tensor], + sampling_rate: float, + ): + """Feeed audio samples to each stream. + Args: + audio_samples: + A list of 1-D tensors containing the audio samples for each + utterance in the batch. If an entry is empty, it means + end-of-utterance has been reached. + sampling_rate: + Sampling rate of the given audio samples. + """ + assert len(audio_samples) == len(self.streams) + for stream, samples in zip(self.streams, audio_samples): + + if stream.done: + assert samples.numel() == 0 + continue + + stream.accept_waveform( + sampling_rate=sampling_rate, + waveform=samples, + ) + + if samples.numel() == 0: + stream.input_finished() + + def build_batch( + self, + chunk_length: int, + segment_length: int, + ) -> Tuple[Optional[torch.Tensor], Optional[List[FeatureExtractionStream]]]: + """ + Args: + chunk_length: + Number of frames for each chunk. It equals to + ``segment_length + right_context_length``. + segment_length + Number of frames for each segment. + Returns: + Return a tuple containing: + - features, a 3-D tensor of shape ``(num_active_streams, T, C)`` + - active_streams, a list of active streams. We say a stream is + active when it has enough feature frames to be fed into the + encoder model. 
+ """ + feature_list = [] + stream_list = [] + for stream in self.streams: + if len(stream.feature_frames) >= chunk_length: + # this_chunk is a list of tensors, each of which + # has a shape (1, feature_dim) + chunk = stream.feature_frames[:chunk_length] + stream.feature_frames = stream.feature_frames[segment_length:] + features = torch.cat(chunk, dim=0) + feature_list.append(features) + stream_list.append(stream) + elif stream.done and len(stream.feature_frames) > 0: + chunk = stream.feature_frames[:chunk_length] + stream.feature_frames = [] + features = torch.cat(chunk, dim=0) + features = torch.nn.functional.pad( + features, + (0, 0, 0, chunk_length - features.size(0)), + mode="constant", + value=LOG_EPSILON, + ) + feature_list.append(features) + stream_list.append(stream) + + if len(feature_list) == 0: + return None, None + + features = torch.stack(feature_list, dim=0) + return features, stream_list + + +def greedy_search( model: nn.Module, - fbank: kaldifeat.Fbank, - params: AttributeDict, + streams: List[FeatureExtractionStream], + encoder_out: torch.Tensor, sp: spm.SentencePieceProcessor, ): - """Decode one utterance. + """ Args: - audio_samples: - A 1-D float32 tensor of shape (num_samples,) containing the normalized - audio samples. Normalized means the samples is in the range [-1, 1]. model: The RNN-T model. - feature_extractor: - The feature extractor. + stream: + A stream object. + encoder_out: + A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of + the encoder model. + sp: + The BPE model. + """ + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device + + if streams[0].decoder_out is None: + decoder_input = torch.tensor( + [stream.hyp.ys[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=False, + ).squeeze(1) + # decoder_out is of shape (N, decoder_out_dim) + else: + decoder_out = torch.stack( + [stream.decoder_out for stream in streams], + dim=0, + ) + + assert encoder_out.ndim == 3 + + T = encoder_out.size(1) + for t in range(T): + current_encoder_out = encoder_out[:, t] + # current_encoder_out's shape: (batch_size, encoder_out_dim) + + logits = model.joiner(current_encoder_out, decoder_out) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + streams[i].hyp.ys.append(v) + emitted = True + + if emitted: + # update decoder output + decoder_input = torch.tensor( + [stream.hyp.ys[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder(decoder_input, need_pad=False).squeeze( + 1 + ) + + for k, s in enumerate(streams): + logging.info( + f"Partial result {k}:\n{sp.decode(s.hyp.ys[context_size:])}" + ) + + decoder_out_list = decoder_out.unbind(dim=0) + + for i, d in enumerate(decoder_out_list): + streams[i].decoder_out = d + + +def process_features( + model: nn.Module, + features: torch.Tensor, + streams: List[FeatureExtractionStream], + sp: spm.SentencePieceProcessor, +) -> None: + """Process features for each stream in parallel. + + Args: + model: + The RNN-T model. + features: + A 3-D tensor of shape (N, T, C). + streams: + A list of streams of size (N,). + sp: + The BPE model. 
+ """ + assert features.ndim == 3 + assert features.size(0) == len(streams) + batch_size = features.size(0) + + device = model.device + features = features.to(device) + feature_lens = torch.full( + (batch_size,), + fill_value=features.size(1), + device=device, + ) + + # Caution: It has a limitation as it assumes that + # if one of the stream has an empty state, then all other + # streams also have empty states. + if streams[0].states is None: + states = None + else: + state_list = [stream.states for stream in streams] + states = stack_states(state_list) + + (encoder_out, encoder_out_lens, states,) = model.encoder.streaming_forward( + features, + feature_lens, + states, + ) + state_list = unstack_states(states) + for i, s in enumerate(state_list): + streams[i].states = s + + greedy_search( + model=model, + streams=streams, + encoder_out=encoder_out, + sp=sp, + ) + + +def decode_batch( + batched_samples: List[torch.Tensor], + model: nn.Module, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> List[str]: + """ + Args: + batched_samples: + A list of 1-D tensors containing the audio samples of each utterance. + model: + The RNN-T model. params: It is the return value of :func:`get_params`. sp: The BPE model. """ - sample_rate = params.sample_rate - frame_shift = sample_rate * fbank.opts.frame_opts.frame_shift_ms / 1000 + # number of frames before subsampling + segment_length = model.encoder.segment_length - frame_shift = int(frame_shift) # number of samples + right_context_length = model.encoder.right_context_length - # Note: We add 3 here because the subsampling method ((n-1)//2-1))//2 - # is not equal to n//4. We will switch to a subsampling method that - # satisfies n//4, where n is the number of input frames. - segment_length = (params.segment_length + 3) * frame_shift + # We add 3 here since the subsampling method is using + # ((len - 1) // 2 - 1) // 2) + chunk_length = (segment_length + 3) + right_context_length - right_context_length = params.right_context_length * frame_shift - chunk_size = segment_length + right_context_length + batch_size = len(batched_samples) + streaming_audio_samples = StreamingAudioSamples(batched_samples) - opts = fbank.opts.frame_opts - chunk_size += ( - (opts.frame_length_ms - opts.frame_shift_ms) / 1000 * sample_rate + stream_list = StreamList( + batch_size=batch_size, + context_size=params.context_size, + blank_id=params.blank_id, ) - chunk_size = int(chunk_size) - - states: Optional[List[List[torch.Tensor]]] = None - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - - device = model.device - - hyp = [blank_id] * context_size - - decoder_input = torch.tensor(hyp, device=device, dtype=torch.int64).reshape( - 1, context_size - ) - - decoder_out = model.decoder(decoder_input, need_pad=False) - - i = 0 - num_samples = audio_samples.size(0) - while i < num_samples: - # Note: The current approach of computing the features is not ideal - # since it re-computes the features for the right context. 
- chunk = audio_samples[i : i + chunk_size] # noqa - i += segment_length - if chunk.size(0) < chunk_size: - chunk = torch.nn.functional.pad( - chunk, pad=(0, chunk_size - chunk.size(0)) - ) - features = fbank(chunk) - feature_lens = torch.tensor([features.size(0)], device=params.device) - - features = features.unsqueeze(0) # (1, T, C) - - encoder_out, encoder_out_lens, states = model.encoder.streaming_forward( - features, - feature_lens, - states, + while not streaming_audio_samples.done: + samples = streaming_audio_samples.get_next() + stream_list.accept_waveform( + audio_samples=samples, + sampling_rate=params.sampling_rate, ) - for t in range(encoder_out_lens.item()): - # fmt: off - current_encoder_out = encoder_out[0:1, t:t+1, :].unsqueeze(2) - # fmt: on - logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1)) - # logits is (1, 1, 1, vocab_size) - y = logits.argmax().item() - if y == blank_id: - continue - - hyp.append(y) - - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device, dtype=torch.int64 - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - logging.info(f"Partial result:\n{sp.decode(hyp[context_size:])}") + features, active_streams = stream_list.build_batch( + chunk_length=chunk_length, + segment_length=segment_length, + ) + if features is not None: + process_features( + model=model, + features=features, + streams=active_streams, + sp=sp, + ) + results = [] + for s in stream_list.streams: + text = sp.decode(s.hyp.ys[params.context_size :]) + results.append(text) + return results @torch.no_grad() @@ -333,30 +570,43 @@ def main(): test_clean_cuts = librispeech.test_clean_cuts() - fbank = get_feature_extractor(params) + batch_size = 3 + ground_truth = [] + batched_samples = [] for num, cut in enumerate(test_clean_cuts): - logging.info("Processing {num}") - audio: np.ndarray = cut.load_audio() # audio.shape: (1, num_samples) assert len(audio.shape) == 2 assert audio.shape[0] == 1, "Should be single channel" assert audio.dtype == np.float32, audio.dtype - assert audio.max() <= 1, "Should be normalized to [-1, 1])" - decode_one_utterance( - audio_samples=torch.from_numpy(audio).squeeze(0).to(device), - model=model, - fbank=fbank, - params=params, - sp=sp, - ) - logging.info(f"The ground truth is:\n{cut.supervisions[0].text}") - if num >= 0: + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + + batched_samples.append(samples) + ground_truth.append(cut.supervisions[0].text) + + if len(batched_samples) >= batch_size: + decoded_results = decode_batch( + batched_samples=batched_samples, + model=model, + params=params, + sp=sp, + ) + s = "\n" + for i, (hyp, ref) in enumerate(zip(decoded_results, ground_truth)): + s += f"hyp {i}:\n{hyp}\n" + s += f"ref {i}:\n{ref}\n\n" + logging.info(s) + batched_samples = [] + ground_truth = [] + # break after processing the first batch for test purposes break - time.sleep(2) # So that you can see the decoded results if __name__ == "__main__": + torch.manual_seed(20220410) main() diff --git a/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py b/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py new file mode 100644 index 000000000..b20f6502f --- /dev/null +++ b/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py @@ -0,0 +1,116 @@ +# Copyright 2022 Xiaomi Corp. 
(authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +import torch +from beam_search import Hypothesis +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def _create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +class FeatureExtractionStream(object): + def __init__(self, context_size: int, blank_id: int = 0) -> None: + """Context size of the RNN-T decoder model.""" + self.feature_extractor = _create_streaming_feature_extractor() + self.hyp = Hypothesis( + ys=([blank_id] * context_size), + log_prob=torch.tensor([0.0]), + ) # for greedy search, will extend it to beam search + + # It contains a list of 1-D tensors representing the feature frames. + self.feature_frames: List[torch.Tensor] = [] + + self.num_fetched_frames = 0 + + # For the emformer model, it contains the states of each + # encoder layer. + self.states: Optional[List[List[torch.Tensor]]] = None + + # For the RNN-T decoder, it contains the decoder output + # corresponding to the decoder input self.hyp.ys[-context_size:] + # Its shape is (decoder_out_dim,) + self.decoder_out: Optional[torch.Tensor] = None + + # After calling `self.input_finished()`, we set this flag to True + self._done = False + + def accept_waveform( + self, + sampling_rate: float, + waveform: torch.Tensor, + ) -> None: + """Feed audio samples to the feature extractor and compute features + if there are enough samples available. + + Caution: + The range of the audio samples should match the one used in the + training. That is, if you use the range [-1, 1] in the training, then + the input audio samples should also be normalized to [-1, 1]. + + Args + sampling_rate: + The sampling rate of the input audio samples. It is used for sanity + check to ensure that the input sampling rate equals to the one + used in the extractor. If they are not equal, then no resampling + will be performed; instead an error will be thrown. + waveform: + A 1-D torch tensor of dtype torch.float32 containing audio samples. + It should be on CPU. + """ + self.feature_extractor.accept_waveform( + sampling_rate=sampling_rate, + waveform=waveform, + ) + self._fetch_frames() + + def input_finished(self) -> None: + """Signal that no more audio samples available and the feature + extractor should flush the buffered samples to compute frames. 
+ """ + self.feature_extractor.input_finished() + self._fetch_frames() + self._done = True + + @property + def done(self) -> bool: + """Return True if `self.input_finished()` has been invoked""" + return self._done + + def _fetch_frames(self) -> None: + """Fetch frames from the feature extractor""" + while self.num_fetched_frames < self.feature_extractor.num_frames_ready: + frame = self.feature_extractor.get_frame(self.num_fetched_frames) + self.feature_frames.append(frame) + self.num_fetched_frames += 1 diff --git a/egs/librispeech/ASR/transducer_emformer/test_emformer.py b/egs/librispeech/ASR/transducer_emformer/test_emformer.py index d8c7b37e2..239ed24ac 100755 --- a/egs/librispeech/ASR/transducer_emformer/test_emformer.py +++ b/egs/librispeech/ASR/transducer_emformer/test_emformer.py @@ -25,7 +25,7 @@ To run this file, do: import warnings import torch -from emformer import Emformer +from emformer import Emformer, stack_states, unstack_states def test_emformer(): @@ -65,8 +65,41 @@ def test_emformer(): print(f"Number of encoder parameters: {num_param}") +def test_emformer_streaming_forward(): + N = 3 + C = 80 + + output_dim = 500 + + encoder = Emformer( + num_features=C, + output_dim=output_dim, + d_model=512, + nhead=8, + dim_feedforward=2048, + num_encoder_layers=20, + segment_length=16, + left_context_length=120, + right_context_length=4, + vgg_frontend=False, + ) + + x = torch.rand(N, 23, C) + x_lens = torch.full((N,), 23) + y, y_lens, states = encoder.streaming_forward(x=x, x_lens=x_lens) + + state_list = unstack_states(states) + states2 = stack_states(state_list) + + for ss, ss2 in zip(states, states2): + for s, s2 in zip(ss, ss2): + assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}" + + +@torch.no_grad() def main(): - test_emformer() + # test_emformer() + test_emformer_streaming_forward() if __name__ == "__main__": diff --git a/egs/librispeech/ASR/transducer_emformer/test_streaming_feature_extractor.py b/egs/librispeech/ASR/transducer_emformer/test_streaming_feature_extractor.py new file mode 100755 index 000000000..4ce9c3284 --- /dev/null +++ b/egs/librispeech/ASR/transducer_emformer/test_streaming_feature_extractor.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
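A side note on the hard-coded input length 23 in test_emformer_streaming_forward() above: it is the chunk length that decode_batch() in streaming_decode.py derives from the model configuration, namely (segment_length + 3) + right_context_length, where the extra 3 frames compensate for the ((len - 1) // 2 - 1) // 2 convolutional subsampling mentioned in the diff. A small worked check using the test's values, illustrative only:

# Values from test_emformer_streaming_forward() above.
segment_length = 16        # frames per segment, before subsampling
right_context_length = 4   # right-context frames, before subsampling

# Same formula as used in decode_batch() of streaming_decode.py.
chunk_length = (segment_length + 3) + right_context_length
assert chunk_length == 23  # the T dimension fed to streaming_forward()

# Applying the subsampling formula quoted in the diff to one chunk:
assert ((chunk_length - 1) // 2 - 1) // 2 == 5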
+ + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./transducer_emformer/test_streaming_feature_extractor.py +""" + +import torch +from streaming_feature_extractor import FeatureExtractionStream + + +def test_streaming_feature_extractor(): + stream = FeatureExtractionStream(context_size=2, blank_id=0) + samples = torch.rand(16000) + start = 0 + while True: + n = torch.randint(50, 500, (1,)).item() + end = start + n + this_chunk = samples[start:end] + start = end + + if len(this_chunk) == 0: + break + stream.accept_waveform(sampling_rate=16000, waveform=this_chunk) + print(len(stream.feature_frames)) + stream.input_finished() + print(len(stream.feature_frames)) + + +def main(): + test_streaming_feature_extractor() + + +if __name__ == "__main__": + main()
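Finally, a sketch of the chunking policy implemented by StreamList.build_batch() in streaming_decode.py: each call hands chunk_length feature frames to the encoder but advances the stream by only segment_length frames, so the trailing right-context frames are fed again at the start of the next chunk. This is illustrative only; integers stand in for per-frame feature tensors, and the lengths are the same hypothetical values as in the examples above.

segment_length = 16                       # frames per segment, before subsampling
chunk_length = (segment_length + 3) + 4   # + right_context_length, see decode_batch()

frames = list(range(60))                  # stand-ins for buffered feature frames
chunks = []
while len(frames) >= chunk_length:
    chunks.append(frames[:chunk_length])  # what build_batch() feeds the encoder
    frames = frames[segment_length:]      # what it keeps for the next call

print([(c[0], c[-1]) for c in chunks])
# [(0, 22), (16, 38), (32, 54)] -> consecutive chunks overlap by 7 frames

The real build_batch() additionally pads a final partial chunk up to chunk_length with LOG_EPSILON once a stream has reached end of utterance, so no trailing frames are dropped.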