diff --git a/.flake8 b/.flake8 index 229cf1d6c..523ac9296 100644 --- a/.flake8 +++ b/.flake8 @@ -14,3 +14,7 @@ exclude = .git, **/data/**, icefall/shared/make_kn_lm.py + +ignore = + # E203 whitespace before ':' + E203, diff --git a/egs/librispeech/ASR/transducer_emformer/emformer.py b/egs/librispeech/ASR/transducer_emformer/emformer.py index b3693d660..631ff43fb 100644 --- a/egs/librispeech/ASR/transducer_emformer/emformer.py +++ b/egs/librispeech/ASR/transducer_emformer/emformer.py @@ -63,11 +63,11 @@ class Emformer(EncoderInterface): num_encoder_layers: Number of encoder layers. segment_length: - Number of frames per segment. + Number of frames per segment before subsampling. left_context_length: - Number of frames in the left context. + Number of frames in the left context before subsampling. right_context_length: - Number of frames in the right context. + Number of frames in the right context before subsampling. max_memory_size: TODO. dropout: @@ -94,6 +94,7 @@ class Emformer(EncoderInterface): else: self.encoder_embed = Conv2dSubsampling(num_features, d_model) + self.segment_length = segment_length self.right_context_length = right_context_length assert right_context_length % subsampling_factor == 0 diff --git a/egs/librispeech/ASR/transducer_emformer/export.py b/egs/librispeech/ASR/transducer_emformer/export.py new file mode 100755 index 000000000..9c665ea1a --- /dev/null +++ b/egs/librispeech/ASR/transducer_emformer/export.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" +Usage: +./transducer_emformer/export.py \ + --exp-dir ./transducer_emformer/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +It will generate a file exp_dir/pretrained.pt + +To use the generated file with `transducer_emformer/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./transducer_emformer/decode.py \ + --exp-dir ./transducer_emformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 1000 \ + --bpe-model data/lang_bpe_500/bpe.model +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + assert args.jit is False, "Support torchscript will be added later" + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.eval() + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py index 0b21b3f59..aca6e444d 100755 --- a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py +++ b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py @@ -20,14 +20,14 @@ import argparse import logging import time from pathlib import Path -from typing import List, Optional -import kaldifeat import numpy as np import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule +from emformer import LOG_EPSILON +from streaming_feature_extractor import Stream from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( @@ -147,10 +147,10 @@ def get_parser(): ) parser.add_argument( - "--sample-rate", - type=int, + "--sampling-rate", + type=float, default=16000, - help="The sample rate of the input 
sound file", + help="Sample rate of the audio", ) add_model_arguments(parser) @@ -158,32 +158,159 @@ def get_parser(): return parser -def get_feature_extractor( - params: AttributeDict, -) -> kaldifeat.Fbank: - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = params.device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = True - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim +def greedy_search( + model: nn.Module, + stream: Stream, + encoder_out: torch.Tensor, + sp: spm.SentencePieceProcessor, +): + """ + Args: + model: + The RNN-T model. + stream: + A stream object. + encoder_out: + A 2-D tensor of shape (T, encoder_out_dim) containing the output of + the encoder model. + sp: + The BPE model. + """ + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device - return kaldifeat.Fbank(opts) + if stream.decoder_out is None: + decoder_input = torch.tensor( + [stream.hyp.ys[-context_size:]], + device=device, + dtype=torch.int64, + ) + stream.decoder_out = model.decoder( + decoder_input, + need_pad=False, + ).unsqueeze(1) + # stream.decoder_out is of shape (1, 1, decoder_out_dim) + + assert encoder_out.ndim == 2 + + T = encoder_out.size(0) + for t in range(T): + current_encoder_out = encoder_out[t].reshape( + 1, 1, 1, encoder_out.size(-1) + ) + logits = model.joiner(current_encoder_out, stream.decoder_out) + # logits is of shape (1, 1, 1, vocab_size) + y = logits.argmax().item() + if y == blank_id: + continue + stream.hyp.ys.append(y) + + decoder_input = torch.tensor( + [stream.hyp.ys[-context_size:]], + device=device, + dtype=torch.int64, + ) + + stream.decoder_out = model.decoder( + decoder_input, + need_pad=False, + ).unsqueeze(1) + + logging.info( + f"Partial result:\n{sp.decode(stream.hyp.ys[context_size:])}" + ) + + +def process_feature_frames( + model: nn.Module, + stream: Stream, + sp: spm.SentencePieceProcessor, +): + """Process the feature frames contained in ``stream.feature_frames``. + Args: + model: + The RNN-T model. + stream: + The stream corresponding to the input audio samples. + sp: + The BPE model. 
+ """ + # number of frames before subsampling + segment_length = model.encoder.segment_length + + right_context_length = model.encoder.right_context_length + + chunk_length = (segment_length + 3) + right_context_length + + device = model.device + while len(stream.feature_frames) >= chunk_length: + # a list of tensor, each with a shape (1, feature_dim) + this_chunk = stream.feature_frames[:chunk_length] + + stream.feature_frames = stream.feature_frames[segment_length:] + features = torch.cat(this_chunk, dim=0).to(device) # (T, feature_dim) + features = features.unsqueeze(0) # (1, T, feature_dim) + feature_lens = torch.tensor([features.size(1)], device=device) + ( + encoder_out, + encoder_out_lens, + stream.states, + ) = model.encoder.streaming_forward( + features, + feature_lens, + stream.states, + ) + greedy_search( + model=model, + stream=stream, + encoder_out=encoder_out[0], + sp=sp, + ) + + if stream.feature_extractor.is_last_frame(stream.num_fetched_frames - 1): + assert len(stream.feature_frames) < chunk_length + + if len(stream.feature_frames) > 0: + this_chunk = stream.feature_frames[:chunk_length] + stream.feature_frames = [] + features = torch.cat(this_chunk, dim=0) # (T, feature_dim) + features = features.to(device).unsqueeze(0) # (1, T, feature_dim) + features = torch.nn.functional.pad( + features, + (0, 0, 0, chunk_length - features.size(1)), + value=LOG_EPSILON, + ) + feature_lens = torch.tensor([features.size(1)], device=device) + ( + encoder_out, + encoder_out_lens, + stream.states, + ) = model.encoder.streaming_forward( + features, + feature_lens, + stream.states, + ) + greedy_search( + model=model, + stream=stream, + encoder_out=encoder_out[0], + sp=sp, + ) def decode_one_utterance( audio_samples: torch.Tensor, model: nn.Module, - fbank: kaldifeat.Fbank, + stream: Stream, params: AttributeDict, sp: spm.SentencePieceProcessor, ): """Decode one utterance. Args: audio_samples: - A 1-D float32 tensor of shape (num_samples,) containing the normalized - audio samples. Normalized means the samples is in the range [-1, 1]. + A 1-D float32 tensor of shape (num_samples,) containing the + audio samples. model: The RNN-T model. feature_extractor: @@ -193,80 +320,23 @@ def decode_one_utterance( sp: The BPE model. """ - sample_rate = params.sample_rate - frame_shift = sample_rate * fbank.opts.frame_opts.frame_shift_ms / 1000 - - frame_shift = int(frame_shift) # number of samples - - # Note: We add 3 here because the subsampling method ((n-1)//2-1))//2 - # is not equal to n//4. We will switch to a subsampling method that - # satisfies n//4, where n is the number of input frames. - segment_length = (params.segment_length + 3) * frame_shift - - right_context_length = params.right_context_length * frame_shift - chunk_size = segment_length + right_context_length - - opts = fbank.opts.frame_opts - chunk_size += ( - (opts.frame_length_ms - opts.frame_shift_ms) / 1000 * sample_rate - ) - - chunk_size = int(chunk_size) - - states: Optional[List[List[torch.Tensor]]] = None - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - - device = model.device - - hyp = [blank_id] * context_size - - decoder_input = torch.tensor(hyp, device=device, dtype=torch.int64).reshape( - 1, context_size - ) - - decoder_out = model.decoder(decoder_input, need_pad=False) - i = 0 num_samples = audio_samples.size(0) while i < num_samples: - # Note: The current approach of computing the features is not ideal - # since it re-computes the features for the right context. 
- chunk = audio_samples[i : i + chunk_size] # noqa - i += segment_length - if chunk.size(0) < chunk_size: - chunk = torch.nn.functional.pad( - chunk, pad=(0, chunk_size - chunk.size(0)) - ) - features = fbank(chunk) - feature_lens = torch.tensor([features.size(0)], device=params.device) + # Simulate streaming. + this_chunk_num_samples = torch.randint(2000, 5000, (1,)).item() - features = features.unsqueeze(0) # (1, T, C) + thiks_chunk_samples = audio_samples[i : (i + this_chunk_num_samples)] + i += this_chunk_num_samples - encoder_out, encoder_out_lens, states = model.encoder.streaming_forward( - features, - feature_lens, - states, + stream.accept_waveform( + sampling_rate=params.sampling_rate, + waveform=thiks_chunk_samples, ) - for t in range(encoder_out_lens.item()): - # fmt: off - current_encoder_out = encoder_out[0:1, t:t+1, :].unsqueeze(2) - # fmt: on - logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1)) - # logits is (1, 1, 1, vocab_size) - y = logits.argmax().item() - if y == blank_id: - continue + process_feature_frames(model=model, stream=stream, sp=sp) - hyp.append(y) - - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device, dtype=torch.int64 - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - logging.info(f"Partial result:\n{sp.decode(hyp[context_size:])}") + stream.input_finished() + process_feature_frames(model=model, stream=stream, sp=sp) @torch.no_grad() @@ -333,10 +403,12 @@ def main(): test_clean_cuts = librispeech.test_clean_cuts() - fbank = get_feature_extractor(params) - for num, cut in enumerate(test_clean_cuts): - logging.info("Processing {num}") + logging.info(f"Processing {num}") + stream = Stream( + context_size=model.decoder.context_size, + blank_id=model.decoder.blank_id, + ) audio: np.ndarray = cut.load_audio() # audio.shape: (1, num_samples) @@ -347,16 +419,17 @@ def main(): decode_one_utterance( audio_samples=torch.from_numpy(audio).squeeze(0).to(device), model=model, - fbank=fbank, + stream=stream, params=params, sp=sp, ) logging.info(f"The ground truth is:\n{cut.supervisions[0].text}") - if num >= 0: + if num >= 2: break time.sleep(2) # So that you can see the decoded results if __name__ == "__main__": + torch.manual_seed(20220410) main() diff --git a/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py b/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py new file mode 100644 index 000000000..90f333694 --- /dev/null +++ b/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py @@ -0,0 +1,106 @@ +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +import torch +from beam_search import Hypothesis +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def _create_streaming_feature_extractr() -> OnlineFeature: + """Create a CPU streaming feature extractor. 
+ + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +class Stream(object): + def __init__(self, context_size: int, blank_id: int = 0) -> None: + """Context size of the RNN-T decoder model.""" + self.feature_extractor = _create_streaming_feature_extractr() + self.hyp = Hypothesis( + ys=([blank_id] * context_size), + log_prob=torch.tensor([0.0]), + ) # for greedy search, will extend it to beam search + + # It contains a list of 1-D tensors representing the feature frames. + self.feature_frames: List[torch.Tensor] = [] + + self.num_fetched_frames = 0 + + # For the emformer model, it contains the states of each + # encoder layer. + self.states: Optional[List[List[torch.Tensor]]] = None + + # For the RNN-T decoder, it contains the decoder output + # corresponding to the decoder input self.hyp.ys[-context_size:] + self.decoder_out: Optional[torch.Tensor] = None + + def accept_waveform( + self, + sampling_rate: float, + waveform: torch.Tensor, + ) -> None: + """Feed audio samples to the feature extractor and compute features + if there are enough samples available. + + Caution: + The range of the audio samples should match the one used in the + training. That is, if you use the range [-1, 1] in the training, then + the input audio samples should also be normalized to [-1, 1]. + + Args + sampling_rate: + The sampling rate of the input audio samples. It is used for sanity + check to ensure that the input sampling rate equals to the one + used in the extractor. If they are not equal, then no resampling + will be performed; instead an error will be thrown. + waveform: + A 1-D torch tensor of dtype torch.float32 containing audio samples. + It should be on CPU. + """ + self.feature_extractor.accept_waveform( + sampling_rate=sampling_rate, + waveform=waveform, + ) + self._fetch_frames() + + def input_finished(self) -> None: + """Signal that no more audio samples available and the feature + extractor should flush the buffered samples to compute frames. + """ + self.feature_extractor.input_finished() + self._fetch_frames() + + def _fetch_frames(self) -> None: + """Fetch frames from the feature extractor""" + while self.num_fetched_frames < self.feature_extractor.num_frames_ready: + frame = self.feature_extractor.get_frame(self.num_fetched_frames) + self.feature_frames.append(frame) + self.num_fetched_frames += 1 diff --git a/egs/librispeech/ASR/transducer_emformer/test_streaming_feature_extractor.py b/egs/librispeech/ASR/transducer_emformer/test_streaming_feature_extractor.py new file mode 100755 index 000000000..502668e83 --- /dev/null +++ b/egs/librispeech/ASR/transducer_emformer/test_streaming_feature_extractor.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./transducer_emformer/test_streaming_feature_extractor.py +""" + +import torch +from streaming_feature_extractor import Stream + + +def test_streaming_feature_extractor(): + stream = Stream(context_size=2, blank_id=0) + samples = torch.rand(16000) + start = 0 + while True: + n = torch.randint(50, 500, (1,)).item() + end = start + n + this_chunk = samples[start:end] + start = end + + if len(this_chunk) == 0: + break + stream.accept_waveform(sampling_rate=16000, waveform=this_chunk) + print(len(stream.feature_frames)) + stream.input_finished() + print(len(stream.feature_frames)) + + +def main(): + test_streaming_feature_extractor() + + +if __name__ == "__main__": + main()
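
export.py writes exp_dir/pretrained.pt via torch.save({"model": model.state_dict()}, ...). As an alternative to the `ln -s pretrained.pt epoch-9999.pt` trick in its usage docstring, the file can also be loaded back directly. A minimal sketch, not part of this patch, assuming the caller has already built `model` with the same get_transducer_model(params) configuration used at export time; only the loading step is shown:

import torch


def load_pretrained(model: torch.nn.Module, filename: str) -> None:
    """Load exp_dir/pretrained.pt (saved as {"model": state_dict}) into model."""
    checkpoint = torch.load(filename, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    model.eval()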
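
When --avg is greater than 1, export.py averages the --avg consecutive checkpoints that end at --epoch. A small sketch, not part of this patch, of the file list it builds, using the --epoch 20 --avg 10 values from its usage docstring:

epoch, avg = 20, 10  # values from the usage example in export.py's docstring
start = epoch - avg + 1
filenames = [
    f"transducer_emformer/exp/epoch-{i}.pt" for i in range(start, epoch + 1)
]
# -> epoch-11.pt, epoch-12.pt, ..., epoch-20.pt (10 checkpoints averaged)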
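
In streaming_decode.py, process_feature_frames() waits for chunk_length = (segment_length + 3) + right_context_length feature frames before each call to streaming_forward(); the "+ 3" compensates for the Conv2dSubsampling output length ((n - 1) // 2 - 1) // 2 being smaller than n // 4, as the removed comment in decode_one_utterance() noted. A minimal sketch of that arithmetic, not part of this patch, with hypothetical segment_length and right_context_length values (the real ones come from the Emformer configuration):

def subsampled_length(n: int) -> int:
    # Output length of the Conv2dSubsampling layer: ((n - 1) // 2 - 1) // 2
    return ((n - 1) // 2 - 1) // 2


segment_length = 16        # hypothetical: frames per segment, before subsampling
right_context_length = 4   # hypothetical: right-context frames, before subsampling

# Number of feature frames process_feature_frames() feeds per step.
chunk_length = (segment_length + 3) + right_context_length

assert subsampled_length(chunk_length) == 5
assert segment_length // 4 + right_context_length // 4 == 5
# Without the "+ 3" the chunk would come out one subsampled frame short:
assert subsampled_length(segment_length + right_context_length) == 4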