From 6f64a0ed8d54c94bae99b3f99fe24e80db1f312d Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 1 Apr 2022 19:07:34 +0800 Subject: [PATCH] Support streaming decoding. --- .../beam_search.py | 13 +- .../ASR/transducer_emformer/beam_search.py | 1 + .../ASR/transducer_emformer/decode.py | 549 ++++++++++++++++++ .../ASR/transducer_emformer/emformer.py | 41 +- .../transducer_emformer/streaming_decode.py | 362 ++++++++++++ .../ASR/transducer_emformer/test_emformer.py | 4 +- .../ASR/transducer_emformer/train.py | 68 ++- 7 files changed, 1023 insertions(+), 15 deletions(-) create mode 120000 egs/librispeech/ASR/transducer_emformer/beam_search.py create mode 100755 egs/librispeech/ASR/transducer_emformer/decode.py create mode 100755 egs/librispeech/ASR/transducer_emformer/streaming_decode.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index 815e1c02a..2cb7a8cba 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from dataclasses import dataclass from typing import Dict, List, Optional @@ -482,8 +483,10 @@ def modified_beam_search( for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() for k in range(len(topk_hyp_indexes)): hyp_idx = topk_hyp_indexes[k] @@ -590,8 +593,10 @@ def _deprecated_modified_beam_search( topk_hyp_indexes = topk_indexes // logits.size(-1) topk_token_indexes = topk_indexes % logits.size(-1) - topk_hyp_indexes = topk_hyp_indexes.tolist() - topk_token_indexes = topk_token_indexes.tolist() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = topk_hyp_indexes.tolist() + topk_token_indexes = topk_token_indexes.tolist() for i in range(len(topk_hyp_indexes)): hyp = A[topk_hyp_indexes[i]] diff --git a/egs/librispeech/ASR/transducer_emformer/beam_search.py b/egs/librispeech/ASR/transducer_emformer/beam_search.py new file mode 120000 index 000000000..227d2247c --- /dev/null +++ b/egs/librispeech/ASR/transducer_emformer/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/transducer_emformer/decode.py b/egs/librispeech/ASR/transducer_emformer/decode.py new file mode 100755 index 000000000..c40b01dfa --- /dev/null +++ b/egs/librispeech/ASR/transducer_emformer/decode.py @@ -0,0 +1,549 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./transducer_emformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./transducer_emformer/exp \ + --max-duration 100 \ + --decoding-method greedy_search + +(2) beam search +./transducer_emformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./transducer_emformer/exp \ + --max-duration 100 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./transducer_emformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./transducer_emformer/exp \ + --max-duration 100 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search +./transducer_emformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./transducer_emformer/exp \ + --max-duration 1500 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 +""" + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--avg-last-n", + type=int, + default=0, + help="""If positive, --epoch and --avg are ignored and it + will use the last n checkpoints exp_dir/checkpoint-xxx.pt + where xxx is the number of processed batches while + saving that checkpoint. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="transducer_emformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = model.device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif params.decoding_method == "fast_beam_search": + return { + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps + } + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 100 + else: + log_interval = 2 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for hyp_words, ref_text in zip(hyps, texts): + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if params.avg_last_n > 0: + filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n] + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.to(device) + model.eval() + model.device = device + + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/transducer_emformer/emformer.py b/egs/librispeech/ASR/transducer_emformer/emformer.py index a78c82c24..b3693d660 100644 --- a/egs/librispeech/ASR/transducer_emformer/emformer.py +++ b/egs/librispeech/ASR/transducer_emformer/emformer.py @@ -16,7 +16,7 @@ import math import warnings -from typing import Tuple +from typing import List, Optional, Tuple import torch import torch.nn as nn @@ -125,7 +125,9 @@ class Emformer(EncoderInterface): ) def forward( - self, x: torch.Tensor, x_lens: torch.Tensor + self, + x: torch.Tensor, + x_lens: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -161,3 +163,38 @@ class Emformer(EncoderInterface): logits = self.encoder_output_layer(emformer_out) return logits, emformer_out_lens + + def streaming_forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + states: Optional[List[List[torch.Tensor]]] = None, + ): + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 2-D tensor of shap containing the number of valid frames for each + element in `x` before padding. + states: + Internal states of the model. + Returns: + Return a tuple containing 3 tensors: + - encoder_out, a 3-D tensor of shape (N, T, C) + - encoder_out_lens: a 1-D tensor of shape (N,) + - next_state, internal model states for the next chunk + """ + x = self.encoder_embed(x) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # Caution: We assume the subsampling factor is 4! + x_lens = ((x_lens - 1) // 2 - 1) // 2 + emformer_out, emformer_out_lens, states = self.model.infer( + x, x_lens, states + ) + + logits = self.encoder_output_layer(emformer_out) + + return logits, emformer_out_lens, states diff --git a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py new file mode 100755 index 000000000..8e836f3f7 --- /dev/null +++ b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py @@ -0,0 +1,362 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import time +from pathlib import Path +from typing import List, Optional + +import kaldifeat +import numpy as np +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import AttributeDict, setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--avg-last-n", + type=int, + default=0, + help="""If positive, --epoch and --avg are ignored and it + will use the last n checkpoints exp_dir/checkpoint-xxx.pt + where xxx is the number of processed batches while + saving that checkpoint. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="transducer_emformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + add_model_arguments(parser) + + return parser + + +def get_feature_extractor( + params: AttributeDict, +) -> kaldifeat.Fbank: + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = params.device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = True + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + return kaldifeat.Fbank(opts) + + +def decode_one_utterance( + audio_samples: torch.Tensor, + model: nn.Module, + fbank: kaldifeat.Fbank, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +): + """Decode one utterance. + Args: + audio_samples: + A 1-D float32 tensor of shape (num_samples,) containing the normalized + audio samples. Normalized means the samples is in the range [-1, 1]. + model: + The RNN-T model. + feature_extractor: + The feature extractor. + params: + It is the return value of :func:`get_params`. + sp: + The BPE model. + """ + sample_rate = params.sample_rate + frame_shift = sample_rate * fbank.opts.frame_opts.frame_shift_ms / 1000 + + frame_shift = int(frame_shift) # number of samples + + # Note: We add 3 here because the subsampling method ((n-1)//2-1))//2 + # is not equal to n//4. We will switch to a subsampling method that + # satisfies n//4, where n is the number of input frames. + segment_length = (params.segment_length + 3) * frame_shift + + right_context_length = params.right_context_length * frame_shift + chunk_size = segment_length + right_context_length + + opts = fbank.opts.frame_opts + chunk_size += ( + (opts.frame_length_ms - opts.frame_shift_ms) / 1000 * sample_rate + ) + + chunk_size = int(chunk_size) + + states: Optional[List[List[torch.Tensor]]] = None + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + + device = model.device + + hyp = [blank_id] * context_size + + decoder_input = torch.tensor(hyp, device=device, dtype=torch.int64).reshape( + 1, context_size + ) + + decoder_out = model.decoder(decoder_input, need_pad=False) + + i = 0 + num_samples = audio_samples.size(0) + while i < num_samples: + # Note: The current approach of computing the features is not ideal + # since it re-computes the features for the right context. + chunk = audio_samples[i : i + chunk_size] # noqa + i += segment_length + if chunk.size(0) < chunk_size: + chunk = torch.nn.functional.pad( + chunk, pad=(0, chunk_size - chunk.size(0)) + ) + features = fbank(chunk) + feature_lens = torch.tensor([features.size(0)], device=params.device) + + features = features.unsqueeze(0) # (1, T, C) + + encoder_out, encoder_out_lens, states = model.encoder.streaming_forward( + features, + feature_lens, + states, + ) + for t in range(encoder_out_lens.item()): + # fmt: off + current_encoder_out = encoder_out[0:1, t:t+1, :].unsqueeze(2) + # fmt: on + logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1)) + # logits is (1, 1, 1, vocab_size) + y = logits.argmax().item() + if y == blank_id: + continue + + hyp.append(y) + + decoder_input = torch.tensor( + [hyp[-context_size:]], device=device, dtype=torch.int64 + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + logging.info(f"Partial result:\n{sp.decode(hyp[context_size:])}") + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + # Note: params.decoding_method is currently not used. + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + setup_logger(f"{params.res_dir}/log-streaming-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + params.device = device + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if params.avg_last_n > 0: + filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n] + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.to(device) + model.eval() + model.device = device + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + + fbank = get_feature_extractor(params) + + for num, cut in enumerate(test_clean_cuts): + if num > 3: + break + logging.info("Processing {num}") + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + decode_one_utterance( + audio_samples=torch.from_numpy(audio).squeeze(0), + model=model, + fbank=fbank, + params=params, + sp=sp, + ) + + logging.info(f"The ground truth is:\n{cut.supervisions[0].text}") + time.sleep(3) # So that you can see the decoded results + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/transducer_emformer/test_emformer.py b/egs/librispeech/ASR/transducer_emformer/test_emformer.py index 67f209112..d8c7b37e2 100755 --- a/egs/librispeech/ASR/transducer_emformer/test_emformer.py +++ b/egs/librispeech/ASR/transducer_emformer/test_emformer.py @@ -41,11 +41,11 @@ def test_emformer(): d_model=512, nhead=8, dim_feedforward=2048, - num_encoder_layers=12, + num_encoder_layers=20, segment_length=16, left_context_length=120, right_context_length=4, - vgg_frontend=True, + vgg_frontend=False, ) x = torch.rand(N, T, C) diff --git a/egs/librispeech/ASR/transducer_emformer/train.py b/egs/librispeech/ASR/transducer_emformer/train.py index 175387e03..a0fd11efe 100755 --- a/egs/librispeech/ASR/transducer_emformer/train.py +++ b/egs/librispeech/ASR/transducer_emformer/train.py @@ -73,6 +73,64 @@ from icefall.utils import ( ) +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--attention-dim", + type=int, + default=512, + help="Attention dim for the Emformer", + ) + + parser.add_argument( + "--nhead", + type=int, + default=8, + help="Number of attention heads for the Emformer", + ) + + parser.add_argument( + "--dim-feedforward", + type=int, + default=2048, + help="Feed-forward dimension for the Emformer", + ) + + parser.add_argument( + "--num-encoder-layers", + type=int, + default=12, + help="Number of encoder layers for the Emformer", + ) + + parser.add_argument( + "--left-context-length", + type=int, + default=120, + help="Number of frames for the left context in the Emformer", + ) + + parser.add_argument( + "--segment-length", + type=int, + default=16, + help="Number of frames for each segment in the Emformer", + ) + + parser.add_argument( + "--right-context-length", + type=int, + default=4, + help="Number of frames for right context in the Emformer", + ) + + parser.add_argument( + "--memory-size", + type=int, + default=0, + help="Number of entries in the memory for the Emformer", + ) + + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -222,6 +280,8 @@ def get_parser(): """, ) + add_model_arguments(parser) + return parser @@ -283,14 +343,7 @@ def get_params() -> AttributeDict: # parameters for Emformer "feature_dim": 80, "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, "vgg_frontend": False, - "left_context_length": 120, # 120 frames - "segment_length": 16, - "right_context_length": 4, # parameters for decoder "embedding_dim": 512, # parameters for Noam @@ -315,6 +368,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: left_context_length=params.left_context_length, segment_length=params.segment_length, right_context_length=params.right_context_length, + max_memory_size=params.memory_size, ) return encoder