From f233b16974455d323a0d2dbde4ee34a067440073 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Fri, 10 Jun 2022 22:10:43 +0800 Subject: [PATCH] add modify_beam_search, fast_beam_search --- .../stream.py | 15 +- .../streaming_decode.py | 310 ++++++++++++++++-- 2 files changed, 301 insertions(+), 24 deletions(-) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py index 6c7c52df4..6c44f3b39 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py @@ -17,6 +17,7 @@ import math from typing import List, Optional, Tuple +import k2 import torch from beam_search import Hypothesis, HypothesisList from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature @@ -47,6 +48,7 @@ class Stream(object): def __init__( self, params: AttributeDict, + decoding_graph: Optional[k2.Fsa] = None, device: torch.device = torch.device("cpu"), LOG_EPS: float = math.log(1e-10), ) -> None: @@ -80,6 +82,13 @@ class Stream(object): log_prob=torch.zeros(1, dtype=torch.float32, device=device), ) ) + elif params.decoding_method == "fast_beam_search": + # feature_len is needed to get partial results. + # The rnnt_decoding_stream for fast_beam_search. + self.rnnt_decoding_stream: k2.RnntDecodingStream = ( + k2.RnntDecodingStream(decoding_graph) + ) + self.hyp: List[int] = None else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" @@ -171,7 +180,9 @@ class Stream(object): """Obtain current decoding result.""" if self.decoding_method == "greedy_search": return self.hyp[self.context_size :] - else: - assert self.decoding_method == "modified_beam_search" + elif self.decoding_method == "modified_beam_search": best_hyp = self.hyps.get_most_probable(length_norm=True) return best_hyp.ys[self.context_size :] + else: + assert self.decoding_method == "fast_beam_search" + return self.hyp diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py index 35a909397..62ba144b4 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py @@ -17,9 +17,7 @@ # limitations under the License. import argparse -import copy import logging -import math import warnings from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -44,8 +42,10 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) +from icefall.decode import one_best_decoding from icefall.utils import ( AttributeDict, + get_texts, setup_logger, store_transcripts, str2bool, @@ -116,7 +116,6 @@ def get_parser(): default="greedy_search", help="""Possible values are: - greedy_search - - beam_search - modified_beam_search - fast_beam_search """, @@ -196,7 +195,15 @@ def greedy_search( encoder_out: torch.Tensor, streams: List[Stream], ) -> List[List[int]]: - + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + streams: + A list of Stream objects. + """ assert len(streams) == encoder_out.size(0) assert encoder_out.ndim == 3 @@ -205,6 +212,8 @@ def greedy_search( device = next(model.parameters()).device T = encoder_out.size(1) + encoder_out = model.joiner.encoder_proj(encoder_out) + decoder_input = torch.tensor( [stream.hyp[-context_size:] for stream in streams], device=device, @@ -248,11 +257,216 @@ def greedy_search( decoder_out = model.joiner.decoder_proj(decoder_out) +def modified_beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[Stream], + beam: int = 4, +): + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The RNN-T model. + encoder_out: + A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of + the encoder model. + streams: + A list of stream objects. + beam: + Number of active paths during the beam search. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert len(streams) == encoder_out.size(0) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device + batch_size = len(streams) + T = encoder_out.size(1) + + B = [stream.hyps for stream in streams] + + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.stack( + [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0 + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) + # logits is of shape (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token != blank_id: + new_ys.append(new_token) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B[i].add(new_hyp) + + for i in range(batch_size): + streams[i].hyps = B[i] + + +def fast_beam_search_one_best( + model: nn.Module, + streams: List[Stream], + encoder_out: torch.Tensor, + processed_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, +) -> List[List[int]]: + """It limits the maximum number of symbols per frame to 1. + + A lattice is first obtained using modified beam search, and then + the shortest path within the lattice is used as the final output. + + Args: + model: + An instance of `Transducer`. + streams: + A list of stream objects. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + processed_lens: + A tensor of shape (N,) containing the number of processed frames + in `encoder_out` before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + Returns: + Return the decoded result. + """ + assert encoder_out.ndim == 3 + + context_size = model.decoder.context_size + vocab_size = model.decoder.vocab_size + + B, T, C = encoder_out.shape + assert B == len(streams) + + config = k2.RnntDecodingConfig( + vocab_size=vocab_size, + decoder_history_len=context_size, + beam=beam, + max_contexts=max_contexts, + max_states=max_states, + ) + individual_streams = [] + for i in range(B): + individual_streams.append(streams[i].rnnt_decoding_stream) + decoding_streams = k2.RnntDecodingStreams(individual_streams, config) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + logits = logits.squeeze(1).squeeze(1) + log_probs = logits.log_softmax(dim=-1) + decoding_streams.advance(log_probs) + + decoding_streams.terminate_and_flush_to_streams() + + # import pdb + + # pdb.set_trace() + lattice = decoding_streams.format_output(processed_lens.tolist()) + + best_path = one_best_decoding(lattice) + hyps = get_texts(best_path) + + for i in range(B): + streams[i].hyp = hyps[i] + + def decode_one_chunk( model: nn.Module, streams: List[Stream], params: AttributeDict, - sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, ) -> List[int]: device = next(model.parameters()).device @@ -292,7 +506,8 @@ def decode_one_chunk( mode="constant", value=LOG_EPSILON, ) - # stack states of all streams + + # Stack states of all streams states = stack_states(state_list) encoder_out, encoder_out_lens, states = model.encoder.infer( @@ -301,7 +516,6 @@ def decode_one_chunk( states=states, num_processed_frames=num_processed_frames, ) - encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": greedy_search( @@ -309,20 +523,29 @@ def decode_one_chunk( streams=streams, encoder_out=encoder_out, ) - # elif params.decoding_method == "modified_beam_search": - # modified_beam_search( - # model=model, - # streams=streams, - # encoder_out=encoder_out, - # sp=sp, - # beam=params.beam_size, - # ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=streams, + encoder_out=encoder_out, + beam=params.beam_size, + ) + elif params.decoding_method == "fast_beam_search": + fast_beam_search_one_best( + model=model, + streams=streams, + encoder_out=encoder_out, + processed_lens=(num_processed_frames >> 2) + encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" ) - # update cached states of each stream + # Update cached states of each stream state_list = unstack_states(states) for i, s in enumerate(state_list): streams[i].states = s @@ -355,9 +578,29 @@ def decode_dataset( model: nn.Module, params: AttributeDict, sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, ): """Decode dataset. + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The Transducer model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. """ device = next(model.parameters()).device @@ -369,7 +612,12 @@ def decode_dataset( streams = [] for num, cut in enumerate(cuts): # Each utterance has a Stream. - stream = Stream(params=params, device=device, LOG_EPS=LOG_EPSILON) + stream = Stream( + params=params, + decoding_graph=decoding_graph, + device=device, + LOG_EPS=LOG_EPSILON, + ) audio: np.ndarray = cut.load_audio() # audio.shape: (1, num_samples) @@ -391,7 +639,7 @@ def decode_dataset( model=model, streams=streams, params=params, - sp=sp, + decoding_graph=decoding_graph, ) for i in sorted(finished_streams, reverse=True): @@ -411,7 +659,7 @@ def decode_dataset( model=model, streams=streams, params=params, - sp=sp, + decoding_graph=decoding_graph, ) for i in sorted(finished_streams, reverse=True): @@ -423,10 +671,17 @@ def decode_dataset( ) del streams[i] - if params.decoding_method == "greedy_search": - return {"greedy_search": decode_results} + key = "greedy_search" + if params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) else: - return {f"beam_size_{params.beam_size}": decode_results} + key = f"beam_size_{params.beam_size}" + + return {key: decode_results} def save_results( @@ -483,6 +738,11 @@ def main(): params = get_params() params.update(vars(args)) + assert params.decoding_method in ( + "greedy_search", + "fast_beam_search", + "modified_beam_search", + ) # Note: params.decoding_method is currently not used. params.res_dir = params.exp_dir / "streaming" / params.decoding_method @@ -616,6 +876,11 @@ def main(): model.eval() model.device = device + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -633,6 +898,7 @@ def main(): model=model, params=params, sp=sp, + decoding_graph=decoding_graph, ) save_results(