diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py index 9263ac449..9bbbf57ee 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py @@ -26,11 +26,15 @@ class DecodeStream(object): def __init__( self, params: AttributeDict, + initial_states: List[torch.Tensor], decoding_graph: Optional[k2.Fsa] = None, device: torch.device = torch.device("cpu"), ) -> None: """ Args: + initial_states: + Initial decode states of the model, e.g. the return value of + `get_init_state` in conformer.py decoding_graph: Decoding graph used for decoding, may be a TrivialGraph or a HLG. device: @@ -41,6 +45,8 @@ class DecodeStream(object): self.params = params + self.states = initial_states + # It contains a 2-D tensors representing the feature frames. self.features: torch.Tensor = None # how many frames are processed. (before subsampling). @@ -56,7 +62,6 @@ class DecodeStream(object): if params.decoding_method == "greedy_search": self.hyp = [params.blank_id] * params.context_size elif params.decoding_method == "fast_beam_search": - # feature_len is needed to get partial results. # The rnnt_decoding_stream for fast_beam_search. self.rnnt_decoding_stream: k2.RnntDecodingStream = ( k2.RnntDecodingStream(decoding_graph) @@ -66,31 +71,6 @@ class DecodeStream(object): False ), f"Decoding method :{params.decoding_method} do not support" - # The caches for streaming conformer - # It is a List containing two tensors, the first one is the cache for - # attention which has a shape of - # (num_encoder_layers, left_context, encoder_dim), - # the second one is the cache of conv_module which has a shape of - # (num_encoder_layers, cnn_module_kernel - 1, encoder_dim). - self.states: List[torch.Tensor] = [ - torch.zeros( - ( - params.num_encoder_layers, - params.left_context, - params.encoder_dim, - ), - device=device, - ), - torch.zeros( - ( - params.num_encoder_layers, - params.cnn_module_kernel - 1, - params.encoder_dim, - ), - device=device, - ), - ] - @property def done(self) -> bool: """Return True if all the features are processed.""" @@ -105,6 +85,8 @@ class DecodeStream(object): def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]: """Consume chunk_size frames of features""" + # plus 3 here because we subsampling features with + # lengths = ((x_lens - 1) // 2 - 1) // 2 ret_chunk_size = min( self.features.size(0) - self.num_processed_frames, chunk_size + 3 ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py index 62f602abc..1f3fa79b7 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py @@ -378,7 +378,7 @@ def decode_one_chunk( ] # Note: states will be modified in streaming_forward. - encoder_out, encoder_out_lens = model.encoder.streaming_forward( + encoder_out, encoder_out_lens, states = model.encoder.streaming_forward( x=features, x_lens=feature_lens, states=states, @@ -462,10 +462,12 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] + initial_states = model.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( params=params, + initial_states=initial_states, decoding_graph=decoding_graph, device=device, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index 84fad00ff..bcb78414f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -353,7 +353,6 @@ def get_params() -> AttributeDict: "nhead": 8, "dim_feedforward": 2048, "num_encoder_layers": 12, - "cnn_module_kernel": 31, "vgg_frontend": False, # parameters for decoder "embedding_dim": 512, @@ -376,7 +375,6 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: nhead=params.nhead, dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, - cnn_module_kernel=params.cnn_module_kernel, vgg_frontend=params.vgg_frontend, dynamic_chunk_training=params.dynamic_chunk_training, short_chunk_size=params.short_chunk_size, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index a0872b934..bde13f32d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -122,6 +122,7 @@ class Conformer(EncoderInterface): causal, ) self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) + self._init_state = torch.jit.Attribute([], List[torch.Tensor]) def forward( self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0 @@ -194,6 +195,55 @@ class Conformer(EncoderInterface): x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) return x, lengths + @torch.jit.export + def get_init_state( + self, left_context: int, device: torch.device + ) -> List[torch.Tensor]: + """Return the initial cache state of the model. + + Args: + left_context: The left context size (in frames after subsampling). + + Returns: + Return the initial state of the model, it is a list containing two + tensors, the first one is the cache for attentions which has a shape + of (num_encoder_layers, left_context, encoder_dim), the second one + is the cache of conv_modules which has a shape of + (num_encoder_layers, cnn_module_kernel - 1, encoder_dim). + + NOTE: the returned tensors are on the given device. + """ + if ( + len(self._init_state) == 2 + and self._init_state[0].size(1) == left_context + ): + # Note: It is OK to share the init state as it is + # not going to be modified by the model + return self._init_state + + init_states: List[torch.Tensor] = [ + torch.zeros( + ( + self.encoder_layers, + left_context, + self.d_model, + ), + device=device, + ), + torch.zeros( + ( + self.encoder_layers, + self.cnn_module_kernel - 1, + self.d_model, + ), + device=device, + ), + ] + + self._init_state = init_states + + return init_states + @torch.jit.export def streaming_forward( self, @@ -206,7 +256,7 @@ class Conformer(EncoderInterface): right_context: int = 4, simulate_streaming: bool = False, processed_lens: Optional[Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: """ Args: x: @@ -296,7 +346,7 @@ class Conformer(EncoderInterface): embed, pos_enc = self.encoder_pos(embed, left_context) embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F) - x = self.encoder.chunk_forward( + x, states = self.encoder.chunk_forward( embed, pos_enc, src_key_padding_mask=src_key_padding_mask, @@ -338,7 +388,7 @@ class Conformer(EncoderInterface): x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - return x, lengths + return x, lengths, states class ConformerEncoderLayer(nn.Module): @@ -490,7 +540,7 @@ class ConformerEncoderLayer(nn.Module): warmup: float = 1.0, left_context: int = 0, right_context: int = 0, - ) -> Tensor: + ) -> Tuple[Tensor, List[Tensor]]: """ Pass the input through the encoder layer. @@ -527,6 +577,12 @@ class ConformerEncoderLayer(nn.Module): # macaron style feed forward module src = src + self.dropout(self.feed_forward_macaron(src)) + # We put the attention cache this level (i.e. before linear transformation) + # to save memory consumption, when decoding in streaming fashion, the + # batch size would be thousands (for 32GB machine), if we cache key & val + # separately, it needs extra several GB memory. + # TODO(WeiKang): Move cache to self_attn level (i.e. cache key & val + # separately) if needed. key = torch.cat([states[0], src], dim=0) val = key if right_context > 0: @@ -560,7 +616,7 @@ class ConformerEncoderLayer(nn.Module): src = self.norm_final(self.balancer(src)) - return src + return src, states class ConformerEncoder(nn.Module): @@ -635,7 +691,7 @@ class ConformerEncoder(nn.Module): warmup: float = 1.0, left_context: int = 0, right_context: int = 0, - ) -> Tensor: + ) -> Tuple[Tensor, List[Tensor]]: r"""Pass the input through the encoder layers in turn. Args: @@ -678,7 +734,7 @@ class ConformerEncoder(nn.Module): for layer_index, mod in enumerate(self.layers): cache = [states[0][layer_index], states[1][layer_index]] - output = mod.chunk_forward( + output, cache = mod.chunk_forward( output, pos_emb, states=cache, @@ -691,7 +747,7 @@ class ConformerEncoder(nn.Module): states[0][layer_index] = cache[0] states[1][layer_index] = cache[1] - return output + return output, states class RelPositionalEncoding(torch.nn.Module): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index 81643e3c4..0781dfb85 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -357,7 +357,8 @@ def decode_one_chunk( for stream in decode_streams: feat, feat_len = stream.get_feature_frames( - (params.decode_chunk_size + 2) * params.subsampling_factor + (params.decode_chunk_size + 2 + params.right_context) + * params.subsampling_factor ) features.append(feat) feature_lens.append(feat_len) @@ -394,8 +395,7 @@ def decode_one_chunk( ] processed_feature_lens = torch.tensor(processed_feature_lens, device=device) - # Note: states will be modified in streaming_forward. - encoder_out, encoder_out_lens = model.encoder.streaming_forward( + encoder_out, encoder_out_lens, states = model.encoder.streaming_forward( x=features, x_lens=feature_lens, states=states, @@ -475,15 +475,17 @@ def decode_dataset( opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 - log_interval = 300 + log_interval = 50 decode_results = [] # Contain decode streams currently running. decode_streams = [] + initial_states = model.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( params=params, + initial_states=initial_states, decoding_graph=decoding_graph, device=device, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index cee271706..d445713fe 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -393,7 +393,6 @@ def get_params() -> AttributeDict: "nhead": 8, "dim_feedforward": 2048, "num_encoder_layers": 12, - "cnn_module_kernel": 31, # parameters for decoder "decoder_dim": 512, # parameters for joiner @@ -416,7 +415,6 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: nhead=params.nhead, dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, - cnn_module_kernel=params.cnn_module_kernel, dynamic_chunk_training=params.dynamic_chunk_training, short_chunk_size=params.short_chunk_size, num_left_chunks=params.num_left_chunks, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 5b3dce853..43af59761 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -58,6 +58,7 @@ Usage: import argparse import logging +import math from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -87,9 +88,12 @@ from icefall.utils import ( AttributeDict, setup_logger, store_transcripts, + str2bool, write_error_stats, ) +LOG_EPS = math.log(1e-10) + def get_parser(): parser = argparse.ArgumentParser( @@ -219,6 +223,70 @@ def get_parser(): Used only when the decoding_method is fast_beam_search_nbest_oracle. """, ) + + parser.add_argument( + "--dynamic-chunk-training", + type=str2bool, + default=False, + help="""Whether to use dynamic_chunk_training, if you want a streaming + model, this requires to be True. + Note: not needed for decoding, adding it here to construct transducer model, + as we reuse the code in train.py. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=25, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + Note: not needed for decoding, adding it here to construct transducer model, + as we reuse the code in train.py. + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="""How many left context can be seen in chunks when calculating attention. + Note: not needed for decoding, adding it here to construct transducer model, + as we reuse the code in train.py. + """, + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) + + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + + parser.add_argument( + "--right-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) return parser @@ -268,9 +336,27 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, ) + + if params.simulate_streaming: + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=feature, + x_lens=feature_lens, + states=[], + chunk_size=params.right_chunk_size, + left_context=params.left_context, + simulate_streaming=True, + ) + else: + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) + hyps = [] if params.decoding_method == "fast_beam_search": @@ -509,6 +595,10 @@ def main(): else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + if params.simulate_streaming: + params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + if params.decoding_method == "fast_beam_search": params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" @@ -544,6 +634,11 @@ def main(): params.unk_id = sp.unk_id() params.vocab_size = sp.get_piece_size() + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + logging.info(params) logging.info("About to create model") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode_stream.py new file mode 120000 index 000000000..30f264813 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode_stream.py @@ -0,0 +1 @@ +../pruned_transducer_stateless/decode_stream.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py new file mode 100755 index 000000000..82d7d024d --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py @@ -0,0 +1,721 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./pruned_transducer_stateless2/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --decoding_method greedy_search \ + --num-decode-streams 200 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from librispeech import LibriSpeech +from torch.nn.utils.rnn import pad_sequence +from train import get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import one_best_decoding +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Support only greedy_search and fast_beam_search now. + """, + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--dynamic-chunk-training", + type=str2bool, + default=False, + help="""Whether to use dynamic_chunk_training, if you want a streaming + model, this requires to be True. + Note: not needed for decoding, adding it here to construct transducer model, + as we reuse the code in train.py. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=25, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + Note: not needed for decoding, adding it here to construct transducer model, + as we reuse the code in train.py. + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="""How many left context can be seen in chunks when calculating attention. + Note: not needed for decoding, adding it here to construct transducer model, + as we reuse the code in train.py. + """, + ) + + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=True, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--right-context", + type=int, + default=4, + help="right context can be seen during decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + return parser + + +def greedy_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[DecodeStream], +) -> List[List[int]]: + + assert len(streams) == encoder_out.size(0) + assert encoder_out.ndim == 3 + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device + T = encoder_out.size(1) + + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + # decoder_out is of shape (N, decoder_out_dim) + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # logging.info(f"decoder_out shape : {decoder_out.shape}") + + for t in range(T): + # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) + current_encoder_out = encoder_out[:, t : t + 1, :] # noqa + + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + # logits'shape (batch_size, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + streams[i].hyp.append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=False, + ) + decoder_out = model.joiner.decoder_proj(decoder_out) + + hyp_tokens = [] + for stream in streams: + hyp_tokens.append(stream.hyp) + return hyp_tokens + + +def fast_beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + processed_lens: torch.Tensor, + decoding_streams: k2.RnntDecodingStreams, +) -> List[List[int]]: + + B, T, C = encoder_out.shape + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + logits = logits.squeeze(1).squeeze(1) + log_probs = logits.log_softmax(dim=-1) + decoding_streams.advance(log_probs) + + decoding_streams.terminate_and_flush_to_streams() + + lattice = decoding_streams.format_output(processed_lens.tolist()) + best_path = one_best_decoding(lattice) + hyp_tokens = get_texts(best_path) + return hyp_tokens + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + + features = [] + feature_lens = [] + states = [] + + rnnt_stream_list = [] + processed_feature_lens = [] + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames( + (params.decode_chunk_size + 2 + params.right_context) + * params.subsampling_factor + ) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_feature_lens.append(stream.feature_len) + if params.decoding_method == "fast_beam_search": + rnnt_stream_list.append(stream.rnnt_decoding_stream) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # if T is less than 7 there will be an error in time reduction layer, + # because we subsample features with ((x_len - 1) // 2 - 1) // 2 + tail_length = 15 + params.right_context * params.subsampling_factor + if features.size(1) < tail_length: + feature_lens += tail_length - features.size(1) + features = torch.cat( + [ + features, + torch.tensor( + LOG_EPS, dtype=features.dtype, device=device + ).expand( + features.size(0), + tail_length - features.size(1), + features.size(2), + ), + ], + dim=1, + ) + + states = [ + torch.stack([x[0] for x in states], dim=2), + torch.stack([x[1] for x in states], dim=2), + ] + processed_feature_lens = torch.tensor(processed_feature_lens, device=device) + + encoder_out, encoder_out_lens, states = model.encoder.streaming_forward( + x=features, + x_lens=feature_lens, + states=states, + left_context=params.left_context, + right_context=params.right_context, + processed_lens=processed_feature_lens, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + hyp_tokens = greedy_search(model, encoder_out, decode_streams) + elif params.decoding_method == "fast_beam_search": + config = k2.RnntDecodingConfig( + vocab_size=params.vocab_size, + decoder_history_len=params.context_size, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config) + processed_lens = processed_feature_lens + encoder_out_lens + hyp_tokens = fast_beam_search( + model, encoder_out, processed_lens, decoding_streams + ) + else: + assert False + + states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = [states[0][i], states[1][i]] + decode_streams[i].feature_len += encoder_out_lens[i] + if params.decoding_method == "fast_beam_search": + decode_streams[i].hyp = hyp_tokens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 50 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + initial_states = model.get_init_state(params.left_context, device=device) + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. + decode_stream = DecodeStream( + params=params, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params, model, sp, decode_streams + ) + for i in sorted(finished_streams, reverse=True): + hyp = decode_streams[i].hyp + if params.decoding_method == "greedy_search": + hyp = hyp[params.context_size :] # noqa + decode_results.append( + ( + decode_streams[i].ground_truth.split(), + sp.decode(hyp).split(), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk(params, model, sp, decode_streams) + for i in sorted(finished_streams, reverse=True): + hyp = decode_streams[i].hyp + if params.decoding_method == "greedy_search": + hyp = hyp[params.context_size :] # noqa + decode_results.append( + ( + decode_streams[i].ground_truth.split(), + sp.decode(hyp).split(), + ) + ) + del decode_streams[i] + + key = "greedy_search" + if params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=sorted(results)) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + # for streaming + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + params.suffix += f"-right-context-{params.right_context}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeech(params.manifest_dir) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_sets = ["test-clean", "test-other"] + test_cuts = [test_clean_cuts, test_other_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index f5a25a226..a8b8c7349 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -286,6 +286,59 @@ def get_parser(): help="The probability to select a batch from the GigaSpeech dataset", ) + parser.add_argument( + "--dynamic-chunk-training", + type=str2bool, + default=False, + help="""Whether to use dynamic_chunk_training, if you want a streaming + model, this requires to be True. + """, + ) + + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=25, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--delay-penalty", + type=float, + default=0.0, + help="""A constant value to penalize symbol delay, this may be + needed when training with time masking, to avoid the time masking + encouraging the network to delay symbols. + """, + ) + + parser.add_argument( + "--return-sym-delay", + type=str2bool, + default=False, + help="""Whether to return `sym_delay` during training, this is a stat + to measure symbols emission delay, especially for time masking training. + """, + ) + return parser @@ -372,6 +425,10 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: nhead=params.nhead, dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, + dynamic_chunk_training=params.dynamic_chunk_training, + short_chunk_size=params.short_chunk_size, + num_left_chunks=params.num_left_chunks, + causal=params.causal_convolution, ) return encoder @@ -905,6 +962,15 @@ def run(rank, world_size, args): params.blank_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() + if params.dynamic_chunk_training: + assert ( + params.causal_convolution + ), "dynamic_chunk_training requires causal convolution" + else: + assert ( + params.delay_penalty == 0.0 + ), "delay_penalty is intended for dynamic_chunk_training" + logging.info(params) logging.info("About to create model") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index db8a80f46..9d340a20e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -350,6 +350,7 @@ def decode_one_batch( encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens, + states=[], chunk_size=params.right_chunk_size, left_context=params.left_context, simulate_streaming=True, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode_stream.py new file mode 120000 index 000000000..30f264813 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode_stream.py @@ -0,0 +1 @@ +../pruned_transducer_stateless/decode_stream.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py new file mode 100755 index 000000000..b9fdaa68e --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py @@ -0,0 +1,783 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./pruned_transducer_stateless2/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --decoding_method greedy_search \ + --num-decode-streams 200 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from torch.nn.utils.rnn import pad_sequence +from train import get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import one_best_decoding +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Support only greedy_search and fast_beam_search now. + """, + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--dynamic-chunk-training", + type=str2bool, + default=False, + help="""Whether to use dynamic_chunk_training, if you want a streaming + model, this requires to be True. + Note: not needed for decoding, adding it here to construct transducer model, + as we reuse the code in train.py. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=25, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + Note: not needed for decoding, adding it here to construct transducer model, + as we reuse the code in train.py. + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="""How many left context can be seen in chunks when calculating attention. + Note: not needed for decoding, adding it here to construct transducer model, + as we reuse the code in train.py. + """, + ) + + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=True, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--right-context", + type=int, + default=4, + help="right context can be seen during decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + return parser + + +def greedy_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[DecodeStream], +) -> List[List[int]]: + + assert len(streams) == encoder_out.size(0) + assert encoder_out.ndim == 3 + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device + T = encoder_out.size(1) + + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + # decoder_out is of shape (N, decoder_out_dim) + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # logging.info(f"decoder_out shape : {decoder_out.shape}") + + for t in range(T): + # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) + current_encoder_out = encoder_out[:, t : t + 1, :] # noqa + + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + # logits'shape (batch_size, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + streams[i].hyp.append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=False, + ) + decoder_out = model.joiner.decoder_proj(decoder_out) + + hyp_tokens = [] + for stream in streams: + hyp_tokens.append(stream.hyp) + return hyp_tokens + + +def fast_beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + processed_lens: torch.Tensor, + decoding_streams: k2.RnntDecodingStreams, +) -> List[List[int]]: + + B, T, C = encoder_out.shape + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + logits = logits.squeeze(1).squeeze(1) + log_probs = logits.log_softmax(dim=-1) + decoding_streams.advance(log_probs) + + decoding_streams.terminate_and_flush_to_streams() + + lattice = decoding_streams.format_output(processed_lens.tolist()) + best_path = one_best_decoding(lattice) + hyp_tokens = get_texts(best_path) + return hyp_tokens + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + + features = [] + feature_lens = [] + states = [] + + rnnt_stream_list = [] + processed_feature_lens = [] + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames( + (params.decode_chunk_size + 2 + params.right_context) + * params.subsampling_factor + ) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_feature_lens.append(stream.feature_len) + if params.decoding_method == "fast_beam_search": + rnnt_stream_list.append(stream.rnnt_decoding_stream) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # if T is less than 7 there will be an error in time reduction layer, + # because we subsample features with ((x_len - 1) // 2 - 1) // 2 + tail_length = 15 + params.right_context * params.subsampling_factor + if features.size(1) < tail_length: + feature_lens += tail_length - features.size(1) + features = torch.cat( + [ + features, + torch.tensor( + LOG_EPS, dtype=features.dtype, device=device + ).expand( + features.size(0), + tail_length - features.size(1), + features.size(2), + ), + ], + dim=1, + ) + + states = [ + torch.stack([x[0] for x in states], dim=2), + torch.stack([x[1] for x in states], dim=2), + ] + processed_feature_lens = torch.tensor(processed_feature_lens, device=device) + + encoder_out, encoder_out_lens, states = model.encoder.streaming_forward( + x=features, + x_lens=feature_lens, + states=states, + left_context=params.left_context, + right_context=params.right_context, + processed_lens=processed_feature_lens, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + hyp_tokens = greedy_search(model, encoder_out, decode_streams) + elif params.decoding_method == "fast_beam_search": + config = k2.RnntDecodingConfig( + vocab_size=params.vocab_size, + decoder_history_len=params.context_size, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config) + processed_lens = processed_feature_lens + encoder_out_lens + hyp_tokens = fast_beam_search( + model, encoder_out, processed_lens, decoding_streams + ) + else: + assert False + + states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = [states[0][i], states[1][i]] + decode_streams[i].feature_len += encoder_out_lens[i] + if params.decoding_method == "fast_beam_search": + decode_streams[i].hyp = hyp_tokens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 50 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + initial_states = model.get_init_state(params.left_context, device=device) + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. + decode_stream = DecodeStream( + params=params, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params, model, sp, decode_streams + ) + for i in sorted(finished_streams, reverse=True): + hyp = decode_streams[i].hyp + if params.decoding_method == "greedy_search": + hyp = hyp[params.context_size :] # noqa + decode_results.append( + ( + decode_streams[i].ground_truth.split(), + sp.decode(hyp).split(), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk(params, model, sp, decode_streams) + for i in sorted(finished_streams, reverse=True): + hyp = decode_streams[i].hyp + if params.decoding_method == "greedy_search": + hyp = hyp[params.context_size :] # noqa + decode_results.append( + ( + decode_streams[i].ground_truth.split(), + sp.decode(hyp).split(), + ) + ) + del decode_streams[i] + + key = "greedy_search" + if params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=sorted(results)) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + # for streaming + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + params.suffix += f"-right-context-{params.right_context}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_sets = ["test-clean", "test-other"] + test_cuts = [test_clean_cuts, test_other_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index ec33565ba..89e424084 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -125,6 +125,8 @@ class Conformer(Transformer): # and throws an error without this change. self.after_norm = identity + self._init_state = torch.jit.Attribute([], List[torch.Tensor]) + def forward( self, x: torch.Tensor, x_lens: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -189,6 +191,55 @@ class Conformer(Transformer): return logits, lengths + @torch.jit.export + def get_init_state( + self, left_context: int, device: torch.device + ) -> List[torch.Tensor]: + """Return the initial cache state of the model. + + Args: + left_context: The left context size (in frames after subsampling). + + Returns: + Return the initial state of the model, it is a list containing two + tensors, the first one is the cache for attentions which has a shape + of (num_encoder_layers, left_context, encoder_dim), the second one + is the cache of conv_modules which has a shape of + (num_encoder_layers, cnn_module_kernel - 1, encoder_dim). + + NOTE: the returned tensors are on the given device. + """ + if ( + len(self._init_state) == 2 + and self._init_state[0].size(1) == left_context + ): + # Note: It is OK to share the init state as it is + # not going to be modified by the model + return self._init_state + + init_states: List[torch.Tensor] = [ + torch.zeros( + ( + self.encoder_layers, + left_context, + self.d_model, + ), + device=device, + ), + torch.zeros( + ( + self.encoder_layers, + self.cnn_module_kernel - 1, + self.d_model, + ), + device=device, + ), + ] + + self._init_state = init_states + + return init_states + @torch.jit.export def streaming_forward( self, @@ -198,7 +249,7 @@ class Conformer(Transformer): chunk_size: int = 16, left_context: int = 64, simulate_streaming: bool = False, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: """ Args: x: @@ -229,7 +280,7 @@ class Conformer(Transformer): - logits, its shape is (batch_size, output_seq_len, output_dim) - logit_lens, a tensor of shape (batch_size,) containing the number of frames in `logits` before padding. - - decode_states, the updated DecodeStates including the information + - states, the updated states(i.e. caches) including the information of current chunk. """ @@ -265,7 +316,7 @@ class Conformer(Transformer): embed, pos_enc = self.encoder_pos(embed, left_context) embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F) - x = self.encoder.chunk_forward( + x, states = self.encoder.chunk_forward( embed, pos_enc, src_key_padding_mask=src_key_padding_mask, @@ -304,7 +355,7 @@ class Conformer(Transformer): logits = self.encoder_output_layer(x) logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - return logits, lengths + return logits, lengths, states class ConformerEncoderLayer(nn.Module): @@ -461,7 +512,7 @@ class ConformerEncoderLayer(nn.Module): src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, left_context: int = 0, - ) -> Tensor: + ) -> Tuple[Tensor, List[Tensor]]: """ Pass the input through the encoder layer. @@ -471,9 +522,9 @@ class ConformerEncoderLayer(nn.Module): states: The decode states for previous frames which contains the cached data. It has two elements, the first element is the attn_cache which has - a shape of (encoder_layers, left_context, batch, attention_dim), + a shape of (left_context, batch, attention_dim), the second element is the conv_cache which has a shape of - (encoder_layers, cnn_module_kernel-1, batch, conv_dim). + (cnn_module_kernel-1, batch, conv_dim). Note: states will be modified in this function. src_mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). @@ -503,6 +554,12 @@ class ConformerEncoderLayer(nn.Module): if self.normalize_before: src = self.norm_mha(src) + # We put the attention cache this level (i.e. before linear transformation) + # to save memory consumption, when decoding in streaming fashion, the + # batch size would be thousands (for 32GB machine), if we cache key & val + # separately, it needs extra several GB memory. + # TODO(WeiKang): Move cache to self_attn level (i.e. cache key & val + # separately) if needed. key = torch.cat([states[0], src], dim=0) val = key states[0] = key[-left_context:, ...] @@ -543,7 +600,7 @@ class ConformerEncoderLayer(nn.Module): if self.normalize_before: src = self.norm_final(src) - return src + return src, states class ConformerEncoder(nn.Module): @@ -612,7 +669,7 @@ class ConformerEncoder(nn.Module): mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, left_context: int = 0, - ) -> Tensor: + ) -> Tuple[Tensor, List[Tensor]]: r"""Pass the input through the encoder layers in turn. Args: @@ -643,7 +700,7 @@ class ConformerEncoder(nn.Module): for layer_index, mod in enumerate(self.layers): cache = [states[0][layer_index], states[1][layer_index]] - output = mod.chunk_forward( + output, cache = mod.chunk_forward( output, pos_emb, states=cache, @@ -654,7 +711,7 @@ class ConformerEncoder(nn.Module): states[0][layer_index] = cache[0] states[1][layer_index] = cache[1] - return output + return output, states class RelPositionalEncoding(torch.nn.Module):