diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py
new file mode 100644
index 000000000..5afe4f859
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py
@@ -0,0 +1,118 @@
+# Copyright 2022 Xiaomi Corp. (authors: Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple
+
+import k2
+import torch
+
+from icefall.utils import AttributeDict
+
+
+class DecodeStream(object):
+    def __init__(
+        self,
+        params: AttributeDict,
+        decoding_graph: Optional[k2.Fsa] = None,
+        device: torch.device = torch.device("cpu"),
+    ) -> None:
+        """
+        Args:
+          decoding_graph:
+            The graph used for decoding; it may be a trivial graph or an HLG.
+            Used only when --decoding-method is fast_beam_search.
+          device:
+            The device to run this stream on.
+        """
+        if decoding_graph is not None:
+            assert device == decoding_graph.device
+
+        # A 2-D tensor containing the feature frames of the current utterance.
+        self.features: torch.Tensor = None
+        # How many frames have been processed so far (before subsampling).
+        self.num_processed_frames: int = 0
+        self._done: bool = False
+        # The transcript of the current utterance.
+        self.ground_truth: str = ""
+        # The decoding result (partial or final) of the current utterance.
+        self.hyp: List = []
+
+        if params.decoding_method == "greedy_search":
+            self.hyp = [params.blank_id] * params.context_size
+        elif params.decoding_method == "fast_beam_search":
+            # feature_len is needed to get partial results.
+            self.feature_len: int = 0
+            # The rnnt_decoding_stream for fast_beam_search.
+            self.rnnt_decoding_stream: k2.RnntDecodingStream = (
+                k2.RnntDecodingStream(decoding_graph)
+            )
+        else:
+            raise ValueError(
+                f"Unsupported decoding method: {params.decoding_method}"
+            )
+
+        # The caches for the streaming conformer. It is a list of two
+        # tensors: the first is the attention cache, with shape
+        # (num_encoder_layers, left_context, encoder_dim); the second is
+        # the conv_module cache, with shape
+        # (num_encoder_layers, cnn_module_kernel - 1, encoder_dim).
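+        # For example, with the defaults in this recipe
+        # (num_encoder_layers=12, left_context=64, cnn_module_kernel=31,
+        # encoder_dim=512), a single stream carries a (12, 64, 512)
+        # attention cache and a (12, 30, 512) conv cache.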
+        self.states: List[torch.Tensor] = [
+            torch.zeros(
+                (
+                    params.num_encoder_layers,
+                    params.left_context,
+                    params.encoder_dim,
+                ),
+                device=device,
+            ),
+            torch.zeros(
+                (
+                    params.num_encoder_layers,
+                    params.cnn_module_kernel - 1,
+                    params.encoder_dim,
+                ),
+                device=device,
+            ),
+        ]
+
+    @property
+    def done(self) -> bool:
+        """Return True if all the features are processed."""
+        return self._done
+
+    def set_features(
+        self,
+        features: torch.Tensor,
+    ) -> None:
+        """Set the features tensor of the current utterance."""
+        self.features = features
+
+    def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]:
+        """Consume chunk_size frames of features.
+
+        Note: 3 extra frames are returned (when available) so that the
+        subsampling module, which keeps ((T - 1) // 2 - 1) // 2 frames,
+        yields exactly chunk_size // 4 output frames per chunk;
+        num_processed_frames still advances by chunk_size, so consecutive
+        chunks overlap by those 3 frames.
+        """
+        ret_chunk_size = min(
+            self.features.size(0) - self.num_processed_frames, chunk_size + 3
+        )
+        ret_features = self.features[
+            self.num_processed_frames : self.num_processed_frames
+            + ret_chunk_size,
+            :,
+        ]
+        self.num_processed_frames += chunk_size
+
+        if self.num_processed_frames >= self.features.size(0):
+            self._done = True
+
+        return ret_features, ret_chunk_size
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py
new file mode 100755
index 000000000..f4a68e958
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py
@@ -0,0 +1,698 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Usage:
+./pruned_transducer_stateless/streaming_decode.py \
+  --epoch 28 \
+  --avg 15 \
+  --exp-dir ./pruned_transducer_stateless/exp \
+  --decoding-method greedy_search \
+  --num-decode-streams 200
+"""
+
+import argparse
+import logging
+import math
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from decode_stream import DecodeStream
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.utils import (
+    AttributeDict,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)
+from icefall.decode import one_best_decoding
+from icefall.utils import get_texts
+from kaldifeat import FbankOptions, Fbank
+from lhotse import CutSet
+from train import get_params, get_transducer_model
+from torch.nn.utils.rnn import pad_sequence
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=28,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 0.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' or '--iter'",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="""Supports only greedy_search and fast_beam_search for now.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during
+        beam search (i.e., `cutoff = max-score - beam`), which is the same
+        as the `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --decoding-method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=32,
+        help="""Used only when --decoding-method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means trigram",
+    )
+
+    parser.add_argument(
+        "--dynamic-chunk-training",
+        type=str2bool,
+        default=False,
+        help="""Whether to use dynamic_chunk_training; this must be True
+        if you want a streaming model.
+        Note: not needed for decoding; it is added here only to construct
+        the transducer model, as we reuse the code in train.py.
+        """,
+    )
+
+    parser.add_argument(
+        "--short-chunk-size",
+        type=int,
+        default=25,
+        help="""Chunk length for dynamic training: the chunk size is either
+        the max sequence length of the current batch or uniformly sampled
+        from (1, short_chunk_size).
+        Note: not needed for decoding; it is added here only to construct
+        the transducer model, as we reuse the code in train.py.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-left-chunks",
+        type=int,
+        default=4,
+        help="""How many chunks of left context can be seen when calculating
+        attention.
+        Note: not needed for decoding; it is added here only to construct
+        the transducer model, as we reuse the code in train.py.
+        """,
+    )
+
+    parser.add_argument(
+        "--causal-convolution",
+        type=str2bool,
+        default=True,
+        help="""Whether to use causal convolution; this must be True when
+        using dynamic_chunk_training.
+        """,
+    )
+
+    parser.add_argument(
+        "--decode-chunk-size",
+        type=int,
+        default=16,
+        help="The chunk size for decoding (in frames after subsampling)",
+    )
+
+    parser.add_argument(
+        "--left-context",
+        type=int,
+        default=64,
+        help="How much left context can be seen during decoding "
+        "(in frames after subsampling)",
+    )
+
+    parser.add_argument(
+        "--num-decode-streams",
+        type=int,
+        default=2000,
+        help="The number of streams that can be decoded in parallel.",
+    )
+
+    return parser
+
+
+def greedy_search(
+    model: nn.Module,
+    encoder_out: torch.Tensor,
+    streams: List[DecodeStream],
+) -> List[List[int]]:
+    """Batched greedy search for one chunk of encoder output."""
+    assert len(streams) == encoder_out.size(0)
+    assert encoder_out.ndim == 3
+
+    blank_id = model.decoder.blank_id
+    context_size = model.decoder.context_size
+    device = model.device
+    T = encoder_out.size(1)
+
+    decoder_input = torch.tensor(
+        [stream.hyp[-context_size:] for stream in streams],
+        device=device,
+        dtype=torch.int64,
+    )
+    # decoder_out is of shape (N, 1, decoder_out_dim)
+    decoder_out = model.decoder(decoder_input, need_pad=False)
+
+    for t in range(T):
+        # current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
+        current_encoder_out = encoder_out[:, t : t + 1, :]
+
+        logits = model.joiner(
+            current_encoder_out.unsqueeze(2),
+            decoder_out.unsqueeze(1),
+        )
+        # logits' shape: (batch_size, vocab_size)
+        logits = logits.squeeze(1).squeeze(1)
+
+        assert logits.ndim == 2, logits.shape
+        y = logits.argmax(dim=1).tolist()
+        emitted = False
+        for i, v in enumerate(y):
+            if v != blank_id:
+                streams[i].hyp.append(v)
+                emitted = True
+        if emitted:
+            # update decoder output
+            decoder_input = torch.tensor(
+                [stream.hyp[-context_size:] for stream in streams],
+                device=device,
+                dtype=torch.int64,
+            )
+            decoder_out = model.decoder(
+                decoder_input,
+                need_pad=False,
+            )
+
+    hyp_tokens = []
+    for stream in streams:
+        hyp_tokens.append(stream.hyp)
+    return hyp_tokens
+
+
+def fast_beam_search(
+    model: nn.Module,
+    encoder_out: torch.Tensor,
+    processed_lens: torch.Tensor,
+    decoding_streams: k2.RnntDecodingStreams,
+) -> List[List[int]]:
+    """Batched fast beam search for one chunk of encoder output."""
+    B, T, C = encoder_out.shape
+    for t in range(T):
+        # shape is a RaggedShape of shape (B, context)
+        # contexts is a Tensor of shape (shape.NumElements(), context_size)
+        shape, contexts = decoding_streams.get_contexts()
+        # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
+        contexts = contexts.to(torch.int64)
+        # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
+        decoder_out = model.decoder(contexts, need_pad=False)
+        # current_encoder_out is of shape
+        # (shape.NumElements(), 1, joiner_dim)
+        # fmt: off
+        current_encoder_out = torch.index_select(
+            encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
+        )
+        # fmt: on
+        logits = model.joiner(
+            current_encoder_out.unsqueeze(2),
+            decoder_out.unsqueeze(1),
+        )
+        logits = logits.squeeze(1).squeeze(1)
+        log_probs = logits.log_softmax(dim=-1)
+        decoding_streams.advance(log_probs)
+
+    decoding_streams.terminate_and_flush_to_streams()
+
+    lattice = decoding_streams.format_output(processed_lens.tolist())
+    best_path = one_best_decoding(lattice)
+    hyp_tokens = get_texts(best_path)
+    return hyp_tokens
+
+
+def decode_one_chunk(
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    decode_streams: List[DecodeStream],
+) -> List[int]:
+    """Decode one chunk of frames for each stream in decode_streams and
+    return the indexes of the finished streams in a list.
+
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+ model: + The neural model. + sp: + The BPE model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + + features = [] + feature_lens = [] + states = [] + + rnnt_stream_list = [] + processed_feature_lens = [] + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames( + params.decode_chunk_size * params.subsampling_factor + ) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + if params.decoding_method == "fast_beam_search": + processed_feature_lens.append(stream.feature_len) + rnnt_stream_list.append(stream.rnnt_decoding_stream) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # if T is less than 7 there will be an error in time reduction layer, + # because we subsample features with ((x_len - 1) // 2 - 1) // 2 + if features.size(1) < 7: + feature_lens += 7 - features.size(1) + features = torch.cat( + [ + features, + torch.tensor( + LOG_EPS, dtype=features.dtype, device=device + ).expand( + features.size(0), 7 - features.size(1), features.size(2) + ), + ], + dim=1, + ) + + states = [ + torch.stack([x[0] for x in states], dim=2), + torch.stack([x[1] for x in states], dim=2), + ] + + # Note: states will be modified in streaming_forward. + encoder_out, encoder_out_lens = model.encoder.streaming_forward( + x=features, + x_lens=feature_lens, + states=states, + left_context=params.left_context, + ) + + if params.decoding_method == "greedy_search": + hyp_tokens = greedy_search(model, encoder_out, decode_streams) + elif params.decoding_method == "fast_beam_search": + config = k2.RnntDecodingConfig( + vocab_size=params.vocab_size, + decoder_history_len=params.context_size, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + processed_feature_lens = torch.tensor( + processed_feature_lens, device=device + ) + decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config) + processed_lens = processed_feature_lens + encoder_out_lens + hyp_tokens = fast_beam_search( + model, encoder_out, processed_lens, decoding_streams + ) + else: + assert False + + states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = [states[0][i], states[1][i]] + if params.decoding_method == "fast_beam_search": + decode_streams[i].feature_len += encoder_out_lens[i] + decode_streams[i].hyp = hyp_tokens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. 
Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 300 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. + decode_stream = DecodeStream( + params=params, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + decode_stream.set_features(fbank(samples.to(device))) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params, model, sp, decode_streams + ) + for i in sorted(finished_streams, reverse=True): + hyp = decode_streams[i].hyp + if params.decoding_method == "greedy_search": + hyp = hyp[params.context_size :] + decode_results.append( + ( + decode_streams[i].ground_truth.split(), + sp.decode(hyp).split(), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk(params, model, sp, decode_streams) + for i in sorted(finished_streams, reverse=True): + hyp = decode_streams[i].hyp + if params.decoding_method == "greedy_search": + hyp = hyp[params.context_size :] + decode_results.append( + ( + decode_streams[i].ground_truth.split(), + sp.decode(hyp).split(), + ) + ) + del decode_streams[i] + + key = "greedy_search" + if params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
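+        # write_error_stats() returns the overall WER for this key; we
+        # collect the values to build the per-setting summary written below.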
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + # for streaming + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p 
in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_sets = ["test-clean", "test-other"] + test_cuts = [test_clean_cuts, test_other_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index f505cbc14..41020255b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -330,7 +330,7 @@ def get_params() -> AttributeDict: - subsampling_factor: The subsampling factor for the model. - - attention_dim: Hidden dim for multi-head attention model. + - encoder_dim: Hidden dim for multi-head attention model. - num_decoder_layers: Number of decoder layer of transformer decoder. @@ -350,10 +350,11 @@ def get_params() -> AttributeDict: # parameters for conformer "feature_dim": 80, "subsampling_factor": 4, - "attention_dim": 512, + "encoder_dim": 512, "nhead": 8, "dim_feedforward": 2048, "num_encoder_layers": 12, + "cnn_module_kernel": 31, "vgg_frontend": False, # parameters for decoder "embedding_dim": 512, @@ -372,10 +373,11 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: num_features=params.feature_dim, output_dim=params.vocab_size, subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, + d_model=params.encoder_dim, nhead=params.nhead, dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, + cnn_module_kernel=params.cnn_module_kernel, vgg_frontend=params.vgg_frontend, dynamic_chunk_training=params.dynamic_chunk_training, short_chunk_size=params.short_chunk_size, @@ -862,7 +864,7 @@ def run(rank, world_size, args): optimizer = Noam( model.parameters(), - model_size=params.attention_dim, + model_size=params.encoder_dim, factor=params.lr_factor, warm_step=params.warm_step, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index f80efe277..fb40bf5a5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -19,7 +19,7 @@ import logging import copy import math import warnings -from typing import Optional, Tuple +from typing import List, Optional, Tuple, Union import torch from encoder_interface import EncoderInterface @@ -104,11 +104,12 @@ class Conformer(EncoderInterface): self.encoder_layers = num_encoder_layers self.d_model = d_model + self.cnn_module_kernel = cnn_module_kernel + self.causal = causal self.dynamic_chunk_training = dynamic_chunk_training self.short_chunk_threshold = short_chunk_threshold self.short_chunk_size = short_chunk_size self.num_left_chunks = num_left_chunks - self.causal = causal self.encoder_pos = RelPositionalEncoding(d_model, dropout) @@ -192,11 +193,11 @@ class Conformer(EncoderInterface): x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0, - states: Optional[Tensor] = None, + states: Optional[List[Tensor]] = None, chunk_size: int = 16, 
left_context: int = 64, simulate_streaming: bool = False, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: x: @@ -210,8 +211,11 @@ class Conformer(EncoderInterface): to turn modules on sequentially. states: The decode states for previous frames which contains the cached data. - It has a shape of (2, encoder_layers, left_context, batch, attention_dim), - states[0,...] is the attn_cache, states[1,...] is the conv_cache. + It has two elements, the first element is the attn_cache which has + a shape of (encoder_layers, left_context, batch, attention_dim), + the second element is the conv_cache which has a shape of + (encoder_layers, cnn_module_kernel-1, batch, conv_dim). + Note: If not None, states will be modified in this function. chunk_size: The chunk size for decoding, this will be used to simulate streaming decoding using masking. @@ -234,7 +238,11 @@ class Conformer(EncoderInterface): # x: [N, T, C] # Caution: We assume the subsampling factor is 4! - lengths = ((x_lens - 1) // 2 - 1) // 2 + + # lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning + # + # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 + lengths = (((x_lens - 1) >> 1) - 1) >> 1 if not simulate_streaming: assert ( @@ -242,11 +250,14 @@ class Conformer(EncoderInterface): ), "Require cache when sending data in streaming mode" assert ( - states.shape == (2, self.encoder_layers, left_context, x.size(0), self.d_model) - ), f"""The shape of states MUST be equal to - (2, encoder_layers, left_context, batch, d_model) which is - {(2, self.encoder_layers, left_context, x.size(0), self.d_model)} - given {states.shape}.""" + len(states) == 2 and + states[0].shape == (self.encoder_layers, left_context, x.size(0), self.d_model) and + states[1].shape == (self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model) + ), f"""The length of states MUST be equal to 2, and the shape of + first element should be {(self.encoder_layers, left_context, x.size(0), self.d_model)}, + given {states[0].shape}. the shape of second element should be + {(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)}, + given {states[1].shape}.""" src_key_padding_mask = make_pad_mask(lengths + left_context) @@ -254,7 +265,7 @@ class Conformer(EncoderInterface): embed, pos_enc = self.encoder_pos(embed, left_context) embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F) - x, states = self.encoder( + x = self.encoder( embed, pos_enc, src_key_padding_mask=src_key_padding_mask, @@ -284,7 +295,7 @@ class Conformer(EncoderInterface): num_left_chunks=num_left_chunks, device=x.device ) - x, _ = self.encoder( + x = self.encoder( x, pos_emb, mask=mask, @@ -294,7 +305,7 @@ class Conformer(EncoderInterface): x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - return x, lengths, states + return x, lengths class ConformerEncoderLayer(nn.Module): @@ -376,9 +387,9 @@ class ConformerEncoderLayer(nn.Module): src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, warmup: float = 1.0, - states: Optional[Tensor] = None, + states: Optional[List[Tensor]] = None, left_context: int = 0, - ) -> Tuple[Tensor, Tensor]: + ) -> Tuple[Tensor]: """ Pass the input through the encoder layer. @@ -391,8 +402,11 @@ class ConformerEncoderLayer(nn.Module): bypass layers more frequently. states: The decode states for previous frames which contains the cached data. - It has a shape of (2, encoder_layers, left_context, batch, attention_dim), - states[0,...] 
is the attn_cache, states[1,...] is the conv_cache. + It has two elements, the first element is the attn_cache which has + a shape of (left_context, batch, attention_dim), + the second element is the conv_cache which has a shape of + (cnn_module_kernel-1, batch, conv_dim). + Note: If not None, states will be modified in this function. left_context: left context (in frames) used during streaming decoding. this is used only in real streaming decoding, in other circumstances, it MUST be 0. @@ -425,9 +439,9 @@ class ConformerEncoderLayer(nn.Module): val = src if not self.training and states is not None: # src: [chunk_size, N, F] e.g. [8, 41, 512] - key = torch.cat([states[0, ...], src], dim=0) + key = torch.cat([states[0], src], dim=0) val = key - states[0, ...] = key[-left_context:, ...] + states[0] = key[-left_context:, ...] else: assert left_context == 0 @@ -445,15 +459,12 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - residual = src if not self.training and states is not None: - src = torch.cat([states[1, ...], src], dim=0) - states[1, ...] = src[-left_context:, ...] - - conv = self.conv_module(src) - conv = conv[-residual.size(0) :, :, :] # noqa: E203 - - src = residual + self.dropout(conv) + conv, conv_cache = self.conv_module(src, states[1]) + states[1] = conv_cache + else: + conv = self.conv_module(src) + src = src + self.dropout(conv) # feed forward module src = src + self.dropout(self.feed_forward(src)) @@ -463,7 +474,7 @@ class ConformerEncoderLayer(nn.Module): if alpha != 1.0: src = alpha * src + (1 - alpha) * src_orig - return src, states + return src class ConformerEncoder(nn.Module): @@ -495,9 +506,9 @@ class ConformerEncoder(nn.Module): mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, warmup: float = 1.0, - states: Optional[Tensor] = None, + states: Optional[List[Tensor]] = None, left_context: int = 0, - ) -> Tuple[Tensor, Tensor]: + ) -> Tensor: r"""Pass the input through the encoder layers in turn. Args: @@ -509,8 +520,11 @@ class ConformerEncoder(nn.Module): bypass layers more frequently. states: The decode states for previous frames which contains the cached data. - It has a shape of (2, encoder_layers, left_context, batch, attention_dim), - states[0,...] is the attn_cache, states[1,...] is the conv_cache. + It has two elements, the first element is the attn_cache which has + a shape of (encoder_layers, left_context, batch, attention_dim), + the second element is the conv_cache which has a shape of + (encoder_layers, cnn_module_kernel-1, batch, conv_dim). + Note: If not None, states will be modified in this function. left_context: left context (in frames) used during streaming decoding. this is used only in real streaming decoding, in other circumstances, it MUST be 0. @@ -532,19 +546,21 @@ class ConformerEncoder(nn.Module): assert left_context >= 0 for layer_index, mod in enumerate(self.layers): - output, cache = mod( + cache = None if states is None else [states[0][layer_index], states[1][layer_index]] + output = mod( output, pos_emb, src_mask=mask, src_key_padding_mask=src_key_padding_mask, warmup=warmup, - states=None if states is None else states[:, layer_index, ...], + states=cache, left_context=left_context, ) if states is not None: - states[:, layer_index, ...] 
= cache + states[0][layer_index] = cache[0] + states[1][layer_index] = cache[1] - return output, states + return output class RelPositionalEncoding(torch.nn.Module): @@ -1180,14 +1196,23 @@ class ConvolutionModule(nn.Module): initial_scale=0.25, ) - def forward(self, x: Tensor) -> Tensor: + def forward( + self, + x: Tensor, + cache: Optional[Tensor] = None + ) -> Union[Tensor, Tuple[Tensor, Tensor]]: """Compute convolution module. Args: x: Input tensor (#time, batch, channels). + cache: The cache of depthwise_conv, only used in real streaming + decoding. Returns: - Tensor: Output tensor (#time, batch, channels). + If cache is None return the output tensor (#time, batch, channels). + If cache is not None, return a tuple of Tensor, the first one is + the output tensor (#time, batch, channels), the second one is the + new cache for next chunk (#kernel_size - 1, batch, channels). """ # exchange the temporal dimension and the feature dimension @@ -1201,9 +1226,15 @@ class ConvolutionModule(nn.Module): # 1D Depthwise Conv if self.causal and self.lorder > 0: - # Make depthwise_conv causal by - # manualy padding self.lorder zeros to the left - x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) + if cache is None: + # Make depthwise_conv causal by + # manualy padding self.lorder zeros to the left + x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) + else: + assert not self.training, "Cache should be None in training time" + assert cache.size(0) == self.lorder + x = torch.cat([cache.permute(1, 2, 0), x], dim=2) + cache = x.permute(2, 0, 1)[-self.lorder:,...] x = self.depthwise_conv(x) x = self.deriv_balancer2(x) @@ -1211,7 +1242,7 @@ class ConvolutionModule(nn.Module): x = self.pointwise_conv2(x) # (batch, channel, time) - return x.permute(2, 0, 1) + return x.permute(2, 0, 1) if cache is None else (x.permute(2, 0, 1), cache) class Conv2dSubsampling(nn.Module): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode_stream.py new file mode 120000 index 000000000..30f264813 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode_stream.py @@ -0,0 +1 @@ +../pruned_transducer_stateless/decode_stream.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py new file mode 100755 index 000000000..c95b22c0e --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -0,0 +1,705 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Usage: +./pruned_transducer_stateless2/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --decoding_method greedy_search \ + --num-decode-streams 200 +""" + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import numpy as np +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decode_stream import DecodeStream +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) +from icefall.decode import Nbest, one_best_decoding +from icefall.utils import get_texts +from kaldifeat import FbankOptions, Fbank +from lhotse import CutSet +from train import get_params, get_transducer_model +from torch.nn.utils.rnn import pad_sequence + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Support only greedy_search and fast_beam_search now. + """, + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--dynamic-chunk-training", + type=str2bool, + default=False, + help="""Whether to use dynamic_chunk_training, if you want a streaming + model, this requires to be True. + Note: not needed for decoding, adding it here to construct transducer model, + as we reuse the code in train.py. 
+ """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=25, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + Note: not needed for decoding, adding it here to construct transducer model, + as we reuse the code in train.py. + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="""How many left context can be seen in chunks when calculating attention. + Note: not needed for decoding, adding it here to construct transducer model, + as we reuse the code in train.py. + """, + ) + + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=True, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + return parser + + +def greedy_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[DecodeStream], +) -> List[List[int]]: + + assert len(streams) == encoder_out.size(0) + assert encoder_out.ndim == 3 + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device + T = encoder_out.size(1) + + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + # decoder_out is of shape (N, decoder_out_dim) + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # logging.info(f"decoder_out shape : {decoder_out.shape}") + + for t in range(T): + # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) + current_encoder_out = encoder_out[:, t : t + 1, :] + + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + # logits'shape (batch_size, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + streams[i].hyp.append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=False, + ) + decoder_out = model.joiner.decoder_proj(decoder_out) + + hyp_tokens = [] + for stream in streams: + hyp_tokens.append(stream.hyp) + return hyp_tokens + + +def fast_beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + processed_lens: torch.Tensor, + decoding_streams: k2.RnntDecodingStreams, +) -> List[List[int]]: + + B, T, C = encoder_out.shape + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, 
need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + logits = logits.squeeze(1).squeeze(1) + log_probs = logits.log_softmax(dim=-1) + decoding_streams.advance(log_probs) + + decoding_streams.terminate_and_flush_to_streams() + + lattice = decoding_streams.format_output(processed_lens.tolist()) + best_path = one_best_decoding(lattice) + hyp_tokens = get_texts(best_path) + return hyp_tokens + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + + features = [] + feature_lens = [] + states = [] + + rnnt_stream_list = [] + processed_feature_lens = [] + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames( + params.decode_chunk_size * params.subsampling_factor + ) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + if params.decoding_method == "fast_beam_search": + processed_feature_lens.append(stream.feature_len) + rnnt_stream_list.append(stream.rnnt_decoding_stream) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # if T is less than 7 there will be an error in time reduction layer, + # because we subsample features with ((x_len - 1) // 2 - 1) // 2 + if features.size(1) < 7: + feature_lens += 7 - features.size(1) + features = torch.cat( + [ + features, + torch.tensor( + LOG_EPS, dtype=features.dtype, device=device + ).expand( + features.size(0), 7 - features.size(1), features.size(2) + ), + ], + dim=1, + ) + + states = [ + torch.stack([x[0] for x in states], dim=2), + torch.stack([x[1] for x in states], dim=2), + ] + + # Note: states will be modified in streaming_forward. 
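+    # After stacking along dim=2, states[0] has shape
+    # (num_encoder_layers, left_context, batch, encoder_dim) and states[1]
+    # has shape (num_encoder_layers, cnn_module_kernel - 1, batch,
+    # encoder_dim), matching the shapes streaming_forward() asserts.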
+ encoder_out, encoder_out_lens = model.encoder.streaming_forward( + x=features, + x_lens=feature_lens, + states=states, + left_context=params.left_context, + ) + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + hyp_tokens = greedy_search(model, encoder_out, decode_streams) + elif params.decoding_method == "fast_beam_search": + config = k2.RnntDecodingConfig( + vocab_size=params.vocab_size, + decoder_history_len=params.context_size, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + processed_feature_lens = torch.tensor( + processed_feature_lens, device=device + ) + decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config) + processed_lens = processed_feature_lens + encoder_out_lens + hyp_tokens = fast_beam_search( + model, encoder_out, processed_lens, decoding_streams + ) + else: + assert False + + states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = [states[0][i], states[1][i]] + if params.decoding_method == "fast_beam_search": + decode_streams[i].feature_len += encoder_out_lens[i] + decode_streams[i].hyp = hyp_tokens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 300 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. 
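+        # Note: the features of the whole utterance are computed up front
+        # below, but they are still consumed chunk by chunk through
+        # get_feature_frames(), so only feature extraction is done offline.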
+ decode_stream = DecodeStream( + params=params, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + decode_stream.set_features(fbank(samples.to(device))) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params, model, sp, decode_streams + ) + for i in sorted(finished_streams, reverse=True): + hyp = decode_streams[i].hyp + if params.decoding_method == "greedy_search": + hyp = hyp[params.context_size :] + decode_results.append( + ( + decode_streams[i].ground_truth.split(), + sp.decode(hyp).split(), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk(params, model, sp, decode_streams) + for i in sorted(finished_streams, reverse=True): + hyp = decode_streams[i].hyp + if params.decoding_method == "greedy_search": + hyp = hyp[params.context_size :] + decode_results.append( + ( + decode_streams[i].ground_truth.split(), + sp.decode(hyp).split(), + ) + ) + del decode_streams[i] + + key = "greedy_search" + if params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
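+        # write_error_stats() returns the overall WER for this key; we
+        # collect the values to build the per-setting summary written below.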
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + # for streaming + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p 
in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_sets = ["test-clean", "test-other"] + test_cuts = [test_clean_cuts, test_other_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index d51a96d8e..3bab3ba98 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -393,6 +393,7 @@ def get_params() -> AttributeDict: "nhead": 8, "dim_feedforward": 2048, "num_encoder_layers": 12, + "cnn_module_kernel": 31, # parameters for decoder "decoder_dim": 512, # parameters for joiner @@ -415,6 +416,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: nhead=params.nhead, dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, + cnn_module_kernel=params.cnn_module_kernel, dynamic_chunk_training=params.dynamic_chunk_training, short_chunk_size=params.short_chunk_size, num_left_chunks=params.num_left_chunks, diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 8936af2ab..5cf869119 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -18,7 +18,7 @@ import copy import math import warnings -from typing import Optional, Tuple +from typing import List, Optional, Tuple, Union import torch from torch import Tensor, nn @@ -97,11 +97,13 @@ class Conformer(Transformer): self.encoder_layers = num_encoder_layers self.d_model = d_model + self.cnn_module_kernel = cnn_module_kernel + self.causal = causal + self.dynamic_chunk_training = dynamic_chunk_training self.short_chunk_threshold = short_chunk_threshold self.short_chunk_size = short_chunk_size self.num_left_chunks = num_left_chunks - self.causal = causal self.encoder_pos = RelPositionalEncoding(d_model, dropout) @@ -188,11 +190,11 @@ class Conformer(Transformer): self, x: torch.Tensor, x_lens: torch.Tensor, - states: Optional[torch.Tensor] = None, + states: Optional[List[torch.Tensor]] = None, chunk_size: int = 16, left_context: int = 64, simulate_streaming: bool = False, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: x: @@ -202,8 +204,11 @@ class Conformer(Transformer): `x` before padding. states: The decode states for previous frames which contains the cached data. - It has a shape of (2, encoder_layers, left_context, batch, attention_dim), - states[0,...] is the attn_cache, states[1,...] is the conv_cache. + It has two elements, the first element is the attn_cache which has + a shape of (encoder_layers, left_context, batch, attention_dim), + the second element is the conv_cache which has a shape of + (encoder_layers, cnn_module_kernel-1, batch, conv_dim). + Note: If not None, states will be modified in this function. chunk_size: The chunk size for decoding, this will be used to simulate streaming decoding using masking. 
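A minimal sketch, assuming a streaming-capable Conformer `model` from one of these recipes and the default sizes used elsewhere in this PR, of how a caller can allocate the two-element state list described in the docstring above and drive streaming_forward() chunk by chunk:

    import torch

    num_layers, d_model, kernel = 12, 512, 31  # defaults in this PR
    left_context, batch = 64, 1

    # states[0] is the attention cache, states[1] the depthwise-conv cache.
    states = [
        torch.zeros(num_layers, left_context, batch, d_model),
        torch.zeros(num_layers, kernel - 1, batch, d_model),
    ]

    # x: one chunk of fbank frames, (batch, T, num_features); x_lens: (batch,).
    # states is updated in place on every call, so the same list is passed
    # for each successive chunk.
    out, out_lens = model.streaming_forward(
        x=x,
        x_lens=x_lens,
        states=states,
        chunk_size=16,
        left_context=left_context,
    )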
@@ -226,7 +231,11 @@ class Conformer(Transformer):
 
         # x: [N, T, C]
         # Caution: We assume the subsampling factor is 4!
-        lengths = ((x_lens - 1) // 2 - 1) // 2
+
+        # lengths = ((x_lens - 1) // 2 - 1) // 2  # issues a warning
+        #
+        # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
+        lengths = (((x_lens - 1) >> 1) - 1) >> 1
 
         if not simulate_streaming:
             assert (
@@ -234,11 +243,14 @@ class Conformer(Transformer):
             ), "Require cache when sending data in streaming mode"
 
             assert (
-                states.shape == (2, self.encoder_layers, left_context, x.size(0), self.d_model)
-            ), f"""The shape of states MUST be equal to
-            (2, encoder_layers, left_context, batch, d_model) which is
-            {(2, self.encoder_layers, left_context, x.size(0), self.d_model)}
-            given {states.shape}."""
+                len(states) == 2 and
+                states[0].shape == (self.encoder_layers, left_context, x.size(0), self.d_model) and
+                states[1].shape == (self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)
+            ), f"""The length of states MUST be equal to 2; the shape of the
+            first element should be {(self.encoder_layers, left_context, x.size(0), self.d_model)},
+            given {states[0].shape}; the shape of the second element should be
+            {(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)},
+            given {states[1].shape}."""
 
             src_key_padding_mask = make_pad_mask(lengths + left_context)
 
@@ -246,7 +258,7 @@ class Conformer(Transformer):
             embed, pos_enc = self.encoder_pos(embed, left_context)
             embed = embed.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
 
-            x, states = self.encoder(
+            x = self.encoder(
                 embed,
                 pos_enc,
                 src_key_padding_mask=src_key_padding_mask,
@@ -275,7 +287,7 @@ class Conformer(Transformer):
                 num_left_chunks=num_left_chunks,
                 device=x.device,
             )
-            x, _ = self.encoder(
+            x = self.encoder(
                 x,
                 pos_emb,
                 mask=mask,
@@ -288,7 +300,7 @@ class Conformer(Transformer):
         logits = self.encoder_output_layer(x)
         logits = logits.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
 
-        return logits, lengths, states
+        return logits, lengths
 
 
 class ConformerEncoderLayer(nn.Module):
@@ -369,9 +381,9 @@ class ConformerEncoderLayer(nn.Module):
         pos_emb: Tensor,
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
-        states: Optional[Tensor] = None,
+        states: Optional[List[Tensor]] = None,
         left_context: int = 0,
-    ) -> Tuple[Tensor, Tensor]:
+    ) -> Tensor:
         """
         Pass the input through the encoder layer.
 
@@ -380,9 +392,13 @@ class ConformerEncoderLayer(nn.Module):
           pos_emb: Positional embedding tensor (required).
           src_mask: the mask for the src sequence (optional).
           src_key_padding_mask: the mask for the src keys per batch (optional).
-          states: The decode states for previous frames which contains the cached data.
-            It has a shape of (2, left_context, batch, attention_dim),
-            states[0,...] is the attn_cache, states[1,...] is the conv_cache.
+          states:
+            The decode states of this layer for previous frames, which contain
+            the cached data. It has two elements: the first element is the
+            attn_cache, with a shape of (left_context, batch, attention_dim);
+            the second element is the conv_cache, with a shape of
+            (cnn_module_kernel - 1, batch, conv_dim).
+            Note: If not None, states will be modified in this function.
           left_context:
             left context (in frames) used during streaming decoding.
             this is used only in real streaming decoding, in other
            circumstances, it MUST be 0.
@@ -413,9 +429,9 @@ class ConformerEncoderLayer(nn.Module):
 
         val = src
         if not self.training and states is not None:
             # src: [chunk_size, N, F] e.g. [8, 41, 512]
-            key = torch.cat([states[0, ...], src], dim=0)
+            key = torch.cat([states[0], src], dim=0)
             val = key
-            states[0, ...] = key[-left_context:, ...]
+            states[0] = key[-left_context:, ...]
         else:
             assert left_context == 0
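The attention-cache update in the hunk above is easier to see in isolation: each chunk attends over `left_context` cached frames plus itself, and the cache then keeps only the trailing `left_context` frames for the next chunk. A standalone sketch (shapes follow the `[8, 41, 512]` comment; `attn_cache` plays the role of `states[0]` for a single layer):

import torch

left_context, chunk_size, batch, d_model = 64, 8, 41, 512
attn_cache = torch.zeros(left_context, batch, d_model)  # states[0] of one layer

src = torch.randn(chunk_size, batch, d_model)  # current chunk: [8, 41, 512]
key = torch.cat([attn_cache, src], dim=0)      # (left_context + chunk_size, N, F)
val = key                                      # values span the same cached frames
attn_cache = key[-left_context:, ...]          # roll the cache forward
assert attn_cache.shape == (left_context, batch, d_model)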
@@ -438,13 +454,12 @@ class ConformerEncoderLayer(nn.Module):
             src = self.norm_conv(src)
 
         if not self.training and states is not None:
-            src = torch.cat([states[1, ...], src], dim=0)
-            states[1, ...] = src[-left_context:, ...]
-
-        src = self.conv_module(src)
-        src = src[-residual.size(0) :, :, :]  # noqa: E203
-
+            src, conv_cache = self.conv_module(src, states[1])
+            states[1] = conv_cache
+        else:
+            src = self.conv_module(src)
         src = residual + self.dropout(src)
+
         if not self.normalize_before:
             src = self.norm_conv(src)
 
@@ -459,7 +474,7 @@ class ConformerEncoderLayer(nn.Module):
         if self.normalize_before:
             src = self.norm_final(src)
 
-        return src, states
+        return src
 
 
 class ConformerEncoder(nn.Module):
@@ -490,9 +505,9 @@ class ConformerEncoder(nn.Module):
         pos_emb: Tensor,
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
-        states: Optional[Tensor] = None,
+        states: Optional[List[Tensor]] = None,
         left_context: int = 0,
-    ) -> Tuple[Tensor, Tensor]:
+    ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.
 
         Args:
@@ -500,9 +515,13 @@ class ConformerEncoder(nn.Module):
           pos_emb: Positional embedding tensor (required).
           mask: the mask for the src sequence (optional).
           src_key_padding_mask: the mask for the src keys per batch (optional).
-          states: The decode states for previous frames which contains the cached data.
-            It has a shape of (2, encoder_layers, left_context, batch, attention_dim),
-            states[0,...] is the attn_cache, states[1,...] is the conv_cache.
+          states:
+            The decode states for previous frames, which contain the cached data.
+            It has two elements: the first element is the attn_cache, with a
+            shape of (encoder_layers, left_context, batch, attention_dim);
+            the second element is the conv_cache, with a shape of
+            (encoder_layers, cnn_module_kernel - 1, batch, conv_dim).
+            Note: If not None, states will be modified in this function.
           left_context:
             left context (in frames) used during streaming decoding.
             this is used only in real streaming decoding, in other
             circumstances, it MUST be 0.
@@ -525,18 +544,20 @@ class ConformerEncoder(nn.Module):
         assert left_context >= 0
 
         for layer_index, mod in enumerate(self.layers):
-            output, cache = mod(
+            cache = None if states is None else [states[0][layer_index], states[1][layer_index]]
+            output = mod(
                 output,
                 pos_emb,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
-                states=None if states is None else states[:, layer_index, ...],
+                states=cache,
                 left_context=left_context,
             )
             if states is not None:
-                states[:, layer_index, ...] = cache
+                states[0][layer_index] = cache[0]
+                states[1][layer_index] = cache[1]
 
-        return output, states
+        return output
 
 
 class RelPositionalEncoding(torch.nn.Module):
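To make the write-back at the end of that loop concrete: slicing `states[0][layer_index]` yields a view of the stacked cache, but the layer rebinds `cache[0]` and `cache[1]` to new tensors instead of writing through the view, so the results must be copied back explicitly. A standalone sketch with illustrative shapes:

import torch

num_layers, left_context, kernel, batch, d_model = 12, 64, 31, 1, 512
states = [
    torch.zeros(num_layers, left_context, batch, d_model),  # attn_cache
    torch.zeros(num_layers, kernel - 1, batch, d_model),    # conv_cache
]

for layer_index in range(num_layers):
    # Per-layer list of views into the stacked caches.
    cache = [states[0][layer_index], states[1][layer_index]]
    # ... the layer reads `cache` and rebinds cache[0] / cache[1] to new
    # tensors, so nothing is written into `states` yet ...
    states[0][layer_index] = cache[0]  # copy the updated slices back
    states[1][layer_index] = cache[1]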
@@ -1146,7 +1167,11 @@ class ConvolutionModule(nn.Module):
         )
         self.activation = Swish()
 
-    def forward(self, x: Tensor) -> Tensor:
+    def forward(
+        self,
+        x: Tensor,
+        cache: Optional[Tensor] = None,
+    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
         """Compute convolution module.
 
         Args:
@@ -1165,9 +1190,15 @@ class ConvolutionModule(nn.Module):
 
         # 1D Depthwise Conv
         if self.causal and self.lorder > 0:
-            # Make depthwise_conv causal by
-            # manualy padding self.lorder zeros to the left
-            x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
+            if cache is None:
+                # Make depthwise_conv causal by
+                # manually padding self.lorder zeros to the left
+                x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
+            else:
+                assert not self.training, "cache should be None at training time"
+                assert cache.size(0) == self.lorder
+                x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
+                cache = x.permute(2, 0, 1)[-self.lorder :, ...]
         x = self.depthwise_conv(x)
         # x is (batch, channels, time)
@@ -1179,7 +1210,7 @@ class ConvolutionModule(nn.Module):
 
         x = self.pointwise_conv2(x)  # (batch, channel, time)
 
-        return x.permute(2, 0, 1)
+        return x.permute(2, 0, 1) if cache is None else (x.permute(2, 0, 1), cache)
 
 
 class Swish(torch.nn.Module):
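To close, a standalone sketch of why the cached convolution above matches the zero-padded training-time path: prepending the previous chunk's last `lorder` input frames gives the depthwise convolution exactly the left context it would have seen over the whole sequence, so no zero padding is needed at chunk boundaries (dimensions below are illustrative; at the first chunk the all-zero cache reproduces the training-time zero padding):

import torch
import torch.nn as nn

channels, kernel = 512, 31
lorder = kernel - 1  # left padding a causal conv would otherwise need
conv = nn.Conv1d(channels, channels, kernel, padding=0, groups=channels)

cache = torch.zeros(lorder, 1, channels)  # (lorder, batch, channels)
x = torch.randn(16, 1, channels)          # one chunk: (time, batch, channels)

x = x.permute(1, 2, 0)                    # -> (batch, channels, time)
x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
cache = x.permute(2, 0, 1)[-lorder:, ...] # keep the last lorder input frames
y = conv(x)                               # (batch, channels, chunk_size)
assert y.size(2) == 16                    # output covers exactly this chunk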