From 1d036a6a5a71a5136dd20da98fe615a2acc38f6e Mon Sep 17 00:00:00 2001 From: pkufool Date: Mon, 18 Jul 2022 00:03:51 +0800 Subject: [PATCH] Support streaming decoding --- .../pruned_transducer_stateless5/conformer.py | 279 ++++++- .../decode_stream.py | 1 + .../pruned_transducer_stateless5/export.py | 12 + .../streaming_decode.py | 750 ++++++++++++++++++ 4 files changed, 1026 insertions(+), 16 deletions(-) create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless5/decode_stream.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index 8872ade03..25a447de6 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -132,6 +132,7 @@ class Conformer(EncoderInterface): ) ), ) + self._init_state: List[torch.Tensor] = [torch.empty(0)] def forward( self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0 @@ -197,6 +198,52 @@ class Conformer(EncoderInterface): return x, lengths + @torch.jit.export + def get_init_state( + self, left_context: int, device: torch.device + ) -> List[torch.Tensor]: + """Return the initial cache state of the model. + Args: + left_context: The left context size (in frames after subsampling). + Returns: + Return the initial state of the model, it is a list containing two + tensors, the first one is the cache for attentions which has a shape + of (num_encoder_layers, left_context, encoder_dim), the second one + is the cache of conv_modules which has a shape of + (num_encoder_layers, cnn_module_kernel - 1, encoder_dim). + NOTE: the returned tensors are on the given device. 
+ """ + if ( + len(self._init_state) == 2 + and self._init_state[0].size(1) == left_context + ): + # Note: It is OK to share the init state as it is + # not going to be modified by the model + return self._init_state + + init_states: List[torch.Tensor] = [ + torch.zeros( + ( + self.encoder_layers, + left_context, + self.d_model, + ), + device=device, + ), + torch.zeros( + ( + self.encoder_layers, + self.cnn_module_kernel - 1, + self.d_model, + ), + device=device, + ), + ] + + self._init_state = init_states + + return init_states + @torch.jit.export def streaming_forward( self, @@ -492,6 +539,98 @@ class ConformerEncoderLayer(nn.Module): return src + @torch.jit.export + def chunk_forward( + self, + src: Tensor, + pos_emb: Tensor, + states: List[Tensor], + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + warmup: float = 1.0, + left_context: int = 0, + right_context: int = 0, + ) -> Tuple[Tensor, List[Tensor]]: + """ + Pass the input through the encoder layer. + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + states: + The decode states for previous frames which contains the cached data. + It has two elements, the first element is the attn_cache which has + a shape of (left_context, batch, attention_dim), + the second element is the conv_cache which has a shape of + (cnn_module_kernel-1, batch, conv_dim). + Note: states will be modified in this function. + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + warmup: controls selective bypass of layers; if < 1.0, we will + bypass layers more frequently. + left_context: + How many previous frames the attention can see in current chunk. + Note: It's not that each individual frame has `left_context` frames + of left context, some have more. + right_context: + How many future frames the attention can see in current chunk. 
+ Note: It's not that each individual frame has `right_context` frames + of right context, some have more. + Shape: + src: (S, N, E). + pos_emb: (N, 2*(S+left_context)-1, E). + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, N is the batch size, E is the feature number + """ + + assert not self.training + assert len(states) == 2 + assert states[0].shape == (left_context, src.size(1), src.size(2)) + + # macaron style feed forward module + src = src + self.dropout(self.feed_forward_macaron(src)) + + # We put the attention cache this level (i.e. before linear transformation) + # to save memory consumption, when decoding in streaming fashion, the + # batch size would be thousands (for 32GB machine), if we cache key & val + # separately, it needs extra several GB memory. + # TODO(WeiKang): Move cache to self_attn level (i.e. cache key & val + # separately) if needed. + key = torch.cat([states[0], src], dim=0) + val = key + if right_context > 0: + states[0] = key[ + -(left_context + right_context) : -right_context, ... # noqa + ] + else: + states[0] = key[-left_context:, ...] 
+ + # multi-headed self-attention module + src_att = self.self_attn( + src, + key, + val, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + left_context=left_context, + )[0] + + src = src + self.dropout(src_att) + + # convolution module + conv, conv_cache = self.conv_module(src, states[1], right_context) + states[1] = conv_cache + + src = src + self.dropout(conv) + + # feed forward module + src = src + self.dropout(self.feed_forward(src)) + + src = self.norm_final(self.balancer(src)) + + return src, states + class ConformerEncoder(nn.Module): r"""ConformerEncoder is a stack of N encoder layers @@ -575,6 +714,77 @@ class ConformerEncoder(nn.Module): return output + @torch.jit.export + def chunk_forward( + self, + src: Tensor, + pos_emb: Tensor, + states: List[Tensor], + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + warmup: float = 1.0, + left_context: int = 0, + right_context: int = 0, + ) -> Tuple[Tensor, List[Tensor]]: + r"""Pass the input through the encoder layers in turn. + Args: + src: the sequence to the encoder (required). + pos_emb: Positional embedding tensor (required). + states: + The decode states for previous frames which contains the cached data. + It has two elements, the first element is the attn_cache which has + a shape of (encoder_layers, left_context, batch, attention_dim), + the second element is the conv_cache which has a shape of + (encoder_layers, cnn_module_kernel-1, batch, conv_dim). + Note: states will be modified in this function. + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + warmup: controls selective bypass of layers; if < 1.0, we will + bypass layers more frequently. + left_context: + How many previous frames the attention can see in current chunk. + Note: It's not that each individual frame has `left_context` frames + of left context, some have more. 
+ right_context: + How many future frames the attention can see in current chunk. + Note: It's not that each individual frame has `right_context` frames + of right context, some have more. + Shape: + src: (S, N, E). + pos_emb: (N, 2*(S+left_context)-1, E). + mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + """ + assert not self.training + assert len(states) == 2 + assert states[0].shape == ( + self.num_layers, + left_context, + src.size(1), + src.size(2), + ) + assert states[1].size(0) == self.num_layers + + output = src + + for layer_index, mod in enumerate(self.layers): + cache = [states[0][layer_index], states[1][layer_index]] + output, cache = mod.chunk_forward( + output, + pos_emb, + states=cache, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + warmup=warmup, + left_context=left_context, + right_context=right_context, + ) + states[0][layer_index] = cache[0] + states[1][layer_index] = cache[1] + + return output, states + class RelPositionalEncoding(torch.nn.Module): """Relative positional encoding module. @@ -599,12 +809,13 @@ class RelPositionalEncoding(torch.nn.Module): self.pe = None self.extend_pe(torch.tensor(0.0).expand(1, max_len)) - def extend_pe(self, x: Tensor) -> None: + def extend_pe(self, x: Tensor, left_context: int = 0) -> None: """Reset the positional encodings.""" + x_size_1 = x.size(1) + left_context if self.pe is not None: # self.pe contains both positive and negative parts # the length of self.pe is 2 * input_len - 1 - if self.pe.size(1) >= x.size(1) * 2 - 1: + if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device if self.pe.dtype != x.dtype or str(self.pe.device) != str( x.device @@ -614,9 +825,9 @@ class RelPositionalEncoding(torch.nn.Module): # Suppose `i` means to the position of query vector and `j` means the # position of key vector. 
We use position relative positions when keys # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: + def forward( + self, x: torch.Tensor, left_context: int = 0 + ) -> Tuple[Tensor, Tensor]: """Add positional encoding. Args: x (torch.Tensor): Input tensor (batch, time, `*`). + left_context (int): left context (in frames) used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. Returns: torch.Tensor: Encoded tensor (batch, time, `*`). torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). """ - self.extend_pe(x) + self.extend_pe(x, left_context) + x_size_1 = x.size(1) + left_context pos_emb = self.pe[ :, self.pe.size(1) // 2 - - x.size(1) + - x_size_1 + 1 : self.pe.size(1) // 2 # noqa E203 + x.size(1), ] @@ -721,6 +938,7 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, attn_mask: Optional[Tensor] = None, + left_context: int = 0, ) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: @@ -734,6 +952,9 @@ class RelPositionMultiheadAttention(nn.Module): need_weights: output attn_output_weights. attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch. + left_context (int): left context (in frames) used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. Shape: - Inputs: @@ -779,14 +1000,18 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask, + left_context=left_context, ) - def rel_shift(self, x: Tensor) -> Tensor: + def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor: """Compute relative positional encoding. Args: x: Input tensor (batch, head, time1, 2*time1-1). time1 means the length of query vector. 
+ left_context (int): left context (in frames) used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. Returns: Tensor: tensor of shape (batch, head, time1, time2) @@ -794,14 +1019,17 @@ class RelPositionMultiheadAttention(nn.Module): the key, while time1 is for the query). """ (batch_size, num_heads, time1, n) = x.shape - assert n == 2 * time1 - 1 + time2 = time1 + left_context + assert ( + n == left_context + 2 * time1 - 1 + ), f"{n} == {left_context} + 2 * {time1} - 1" # Note: TorchScript requires explicit arg for stride() batch_stride = x.stride(0) head_stride = x.stride(1) time1_stride = x.stride(2) n_stride = x.stride(3) return x.as_strided( - (batch_size, num_heads, time1, time1), + (batch_size, num_heads, time1, time2), (batch_stride, head_stride, time1_stride - n_stride, n_stride), storage_offset=n_stride * (time1 - 1), ) @@ -823,6 +1051,7 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, attn_mask: Optional[Tensor] = None, + left_context: int = 0, ) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: @@ -840,6 +1069,9 @@ class RelPositionMultiheadAttention(nn.Module): need_weights: output attn_output_weights. attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch. + left_context (int): left context (in frames) used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. 
Shape: Inputs: @@ -1003,7 +1235,8 @@ class RelPositionMultiheadAttention(nn.Module): pos_emb_bsz = pos_emb.size(0) assert pos_emb_bsz in (1, bsz) # actually it is 1 p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) - p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + # (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1) + p = p.permute(0, 2, 3, 1) q_with_bias_u = (q + self._pos_bias_u()).transpose( 1, 2 @@ -1023,9 +1256,9 @@ class RelPositionMultiheadAttention(nn.Module): # compute matrix b and matrix d matrix_bd = torch.matmul( - q_with_bias_v, p.transpose(-2, -1) + q_with_bias_v, p ) # (batch, head, time1, 2*time1-1) - matrix_bd = self.rel_shift(matrix_bd) + matrix_bd = self.rel_shift(matrix_bd, left_context) attn_output_weights = ( matrix_ac + matrix_bd @@ -1201,13 +1434,19 @@ class ConvolutionModule(nn.Module): initial_scale=0.25, ) - def forward(self, x: Tensor, cache: Optional[Tensor] = None) -> Tensor: + def forward( + self, x: Tensor, cache: Optional[Tensor] = None, right_context: int = 0 + ) -> Tuple[Tensor, Tensor]: """Compute convolution module. Args: x: Input tensor (#time, batch, channels). cache: The cache of depthwise_conv, only used in real streaming decoding. + right_context: + How many future frames the attention can see in current chunk. + Note: It's not that each individual frame has `right_context` frames + of right context, some have more. Returns: Tensor: Output tensor (#time, batch, channels). @@ -1237,7 +1476,15 @@ class ConvolutionModule(nn.Module): ), "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) - cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa + if right_context > 0: + cache = x.permute(2, 0, 1)[ + -(self.lorder + right_context) : ( # noqa + -right_context + ), + ..., + ] + else: + cache = x.permute(2, 0, 1)[-self.lorder :, ...] 
# noqa x = self.depthwise_conv(x) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode_stream.py new file mode 120000 index 000000000..30f264813 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode_stream.py @@ -0,0 +1 @@ +../pruned_transducer_stateless/decode_stream.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py index 936508900..fec0254f9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py @@ -137,6 +137,15 @@ def get_parser(): "2 means tri-gram", ) + parser.add_argument( + "--streaming-model", + type=str2bool, + default=False, + help="""Whether to export a streaming model, if the models in exp-dir + are streaming model, this should be True. + """, + ) + add_model_arguments(parser) return parser @@ -162,6 +171,9 @@ def main(): params.blank_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() + if params.streaming_model: + assert params.causal_convolution + logging.info(params) logging.info("About to create model") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py new file mode 100755 index 000000000..640e8006e --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -0,0 +1,750 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./pruned_transducer_stateless5/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --left-context 32 \ + --decode-chunk-size 8 \ + --right-context 0 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --decoding_method greedy_search \ + --num-decode-streams 200 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import one_best_decoding +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Support only greedy_search and fast_beam_search now. + """, + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. 
+ Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--right-context", + type=int, + default=0, + help="right context can be seen during decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def greedy_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[DecodeStream], +) -> List[List[int]]: + + assert len(streams) == encoder_out.size(0) + assert encoder_out.ndim == 3 + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device + T = encoder_out.size(1) + + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + # decoder_out is of shape (N, decoder_out_dim) + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # logging.info(f"decoder_out shape : {decoder_out.shape}") + + for t in range(T): + # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) + current_encoder_out = encoder_out[:, t : t + 1, :] # 
noqa + + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + # logits'shape (batch_size, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + streams[i].hyp.append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=False, + ) + decoder_out = model.joiner.decoder_proj(decoder_out) + + hyp_tokens = [] + for stream in streams: + hyp_tokens.append(stream.hyp) + return hyp_tokens + + +def fast_beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + processed_lens: torch.Tensor, + decoding_streams: k2.RnntDecodingStreams, +) -> List[List[int]]: + + B, T, C = encoder_out.shape + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + logits = logits.squeeze(1).squeeze(1) + log_probs = logits.log_softmax(dim=-1) + decoding_streams.advance(log_probs) + + decoding_streams.terminate_and_flush_to_streams() + + lattice = 
decoding_streams.format_output(processed_lens.tolist()) + best_path = one_best_decoding(lattice) + hyp_tokens = get_texts(best_path) + return hyp_tokens + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk of feature frames for each stream in decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to an utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + + features = [] + feature_lens = [] + states = [] + + rnnt_stream_list = [] + processed_lens = [] + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames( + params.decode_chunk_size * params.subsampling_factor + ) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + if params.decoding_method == "fast_beam_search": + rnnt_stream_list.append(stream.rnnt_decoding_stream) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # if T is less than 7 there will be an error in time reduction layer, + # because we subsample features with ((x_len - 1) // 2 - 1) // 2 + # we add 2 here because we will cut off one frame on each side of + # encoder_embed output as they see invalid paddings, so we need extra 2 + # frames. 
+ tail_length = 7 + (2 + params.right_context) * params.subsampling_factor + if features.size(1) < tail_length: + feature_lens += tail_length - features.size(1) + features = torch.cat( + [ + features, + torch.tensor( + LOG_EPS, dtype=features.dtype, device=device + ).expand( + features.size(0), + tail_length - features.size(1), + features.size(2), + ), + ], + dim=1, + ) + + states = [ + torch.stack([x[0] for x in states], dim=2), + torch.stack([x[1] for x in states], dim=2), + ] + processed_lens = torch.tensor(processed_lens, device=device) + + encoder_out, encoder_out_lens, states = model.encoder.streaming_forward( + x=features, + x_lens=feature_lens, + states=states, + left_context=params.left_context, + right_context=params.right_context, + processed_lens=processed_lens, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + hyp_tokens = greedy_search(model, encoder_out, decode_streams) + elif params.decoding_method == "fast_beam_search": + config = k2.RnntDecodingConfig( + vocab_size=params.vocab_size, + decoder_history_len=params.context_size, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config) + processed_lens = processed_lens + encoder_out_lens + hyp_tokens = fast_beam_search( + model, encoder_out, processed_lens, decoding_streams + ) + else: + assert False + + states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = [states[0][i], states[1][i]] + decode_streams[i].done_frames += encoder_out_lens[i] + if params.decoding_method == "fast_beam_search": + decode_streams[i].hyp = hyp_tokens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: 
spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 50 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + initial_states = model.encoder.get_init_state( + params.left_context, device=device + ) + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. 
+        decode_stream = DecodeStream(
+            params=params,
+            initial_states=initial_states,
+            decoding_graph=decoding_graph,
+            device=device,
+        )
+
+        audio: np.ndarray = cut.load_audio()
+        # audio.shape: (1, num_samples)
+        assert len(audio.shape) == 2
+        assert audio.shape[0] == 1, "Should be single channel"
+        assert audio.dtype == np.float32, audio.dtype
+
+        # The trained model is using normalized samples
+        assert audio.max() <= 1, "Should be normalized to [-1, 1])"
+
+        samples = torch.from_numpy(audio).squeeze(0)
+
+        fbank = Fbank(opts)
+        feature = fbank(samples.to(device))
+        decode_stream.set_features(feature)
+        decode_stream.ground_truth = cut.supervisions[0].text
+
+        decode_streams.append(decode_stream)
+
+        while len(decode_streams) >= params.num_decode_streams:
+            finished_streams = decode_one_chunk(
+                params=params, model=model, decode_streams=decode_streams
+            )
+            for i in sorted(finished_streams, reverse=True):
+                hyp = decode_streams[i].hyp
+                if params.decoding_method == "greedy_search":
+                    hyp = hyp[params.context_size :]  # noqa
+                decode_results.append(
+                    (
+                        decode_streams[i].ground_truth.split(),
+                        sp.decode(hyp).split(),
+                    )
+                )
+                del decode_streams[i]
+
+        if num % log_interval == 0:
+            logging.info(f"Cuts processed until now is {num}.")
+
+    # decode final chunks of last sequences
+    while len(decode_streams):
+        finished_streams = decode_one_chunk(
+            params=params, model=model, decode_streams=decode_streams
+        )
+        for i in sorted(finished_streams, reverse=True):
+            hyp = decode_streams[i].hyp
+            if params.decoding_method == "greedy_search":
+                hyp = hyp[params.context_size :]  # noqa
+            decode_results.append(
+                (
+                    decode_streams[i].ground_truth.split(),
+                    sp.decode(hyp).split(),
+                )
+            )
+            del decode_streams[i]
+
+    key = "greedy_search"
+    if params.decoding_method == "fast_beam_search":
+        key = (
+            f"beam_{params.beam}_"
+            f"max_contexts_{params.max_contexts}_"
+            f"max_states_{params.max_states}"
+        )
+    return {key: decode_results}
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
+):
+    """Write transcripts, error statistics and a WER summary for a test set.
+
+    For every entry of `results_dict` two files are written to
+    `params.res_dir`: `recogs-*.txt` (via `store_transcripts`) and
+    `errs-*.txt` (via `write_error_stats`). Afterwards a single
+    `wer-summary-*.txt` file lists every setting with the value returned by
+    `write_error_stats`, sorted in ascending order, and the same table is
+    logged.
+
+    Args:
+      params:
+        Provides `res_dir` (combined with `/`, so it must be path-like) and
+        `suffix`; both are used to build the output file names.
+      test_set_name:
+        Name of the test set; embedded in file names and log messages.
+      results_dict:
+        Maps the name of a decoding setting to its list of
+        (reference words, hypothesis words) pairs.
+    """
+    test_set_wers = {}
+    for key, results in results_dict.items():
+        recog_path = (
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        # Sort so the stored transcripts have a deterministic order.
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info(f"Wrote detailed error stats to {errs_filename}")
+
+    # (setting, WER) pairs, lowest WER first.
+    sorted_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    # NOTE: `key` is whatever the loop above left behind, i.e. the last key
+    # of `results_dict`. The file name is kept as is so that existing
+    # tooling still finds the summary; an empty `results_dict` raises
+    # NameError here.
+    errs_info = (
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in sorted_wers:
+            print(f"{key}\t{val}", file=f)
+
+    s = f"\nFor {test_set_name}, WER of different settings are:\n"
+    # Only the first (lowest WER) row carries the "best for" note.
+    note = f"\tbest for {test_set_name}"
+    for key, val in sorted_wers:
+        s += f"{key}\t{val}{note}\n"
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    """Decode LibriSpeech test-clean and test-other in streaming mode.
+
+    Steps: parse the command line, derive the result directory and file
+    suffix from the decoding options, build the model, load either a single
+    checkpoint or an average of several checkpoints, run `decode_dataset`
+    on both test sets and store the outcome with `save_results`.
+    """
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    params.res_dir = params.exp_dir / "streaming" / params.decoding_method
+
+    # The suffix encodes every option that influences the results so that
+    # different runs do not overwrite each other's files.
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    # for streaming
+    params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
+    params.suffix += f"-left-context-{params.left_context}"
+    params.suffix += f"-right-context-{params.right_context}"
+
+    # for fast_beam_search
+    if params.decoding_method == "fast_beam_search":
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    # Use the first GPU when available, otherwise fall back to the CPU.
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    # Decoding in streaming requires causal convolution
+    params.causal_convolution = True
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        # Plain averaging over `--avg` checkpoints (or a single checkpoint).
+        if params.iter > 0:
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            # Filter on the checkpoint index itself, not on the
+            # loop-invariant `start`: indices below 1 are skipped (the
+            # averaged-model branch below likewise requires its oldest
+            # checkpoint index to be >= 1), so `--avg` larger than
+            # `--epoch` averages the checkpoints that can exist instead of
+            # requesting epoch-0.pt or producing an empty list.
+            filenames = [
+                f"{params.exp_dir}/epoch-{i}.pt"
+                for i in range(start, params.epoch + 1)
+                if i >= 1
+            ]
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        # Averaging based on the running average stored in the checkpoints:
+        # only the two end points of the range are needed.
+        if params.iter > 0:
+            # One extra checkpoint is required because the start of the
+            # range is excluded.
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            # The last list entry is used as the (excluded) start and the
+            # first entry as the end of the range.
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+    # Attach the device to the model object; presumably read back by the
+    # decoding helpers -- confirm against decode_dataset.
+    model.device = device
+
+    # A decoding graph is only needed for fast_beam_search.
+    decoding_graph = None
+    if params.decoding_method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    test_clean_cuts = librispeech.test_clean_cuts()
+    test_other_cuts = librispeech.test_other_cuts()
+
+    test_sets = ["test-clean", "test-other"]
+    test_cuts = [test_clean_cuts, test_other_cuts]
+
+    for test_set, test_cut in zip(test_sets, test_cuts):
+        results_dict = decode_dataset(
+            cuts=test_cut,
+            params=params,
+            model=model,
+            sp=sp,
+            decoding_graph=decoding_graph,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+    main()