From 5228b44de7c5b3982e5f7d6ca4ffec6b2fb3c5fe Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Tue, 19 Apr 2022 22:00:47 +0800
Subject: [PATCH 1/2] Support modified beam search decoding for streaming
 inference with Emformer model.

---
 .../transducer_emformer/streaming_decode.py   | 197 +++++++++++++++---
 .../streaming_feature_extractor.py            |  53 +++--
 2 files changed, 210 insertions(+), 40 deletions(-)

diff --git a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py
index bb71310b7..f5e24a0d9 100755
--- a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py
+++ b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py
@@ -18,16 +18,23 @@
 import argparse
 import logging
+import warnings
 from pathlib import Path
 from typing import List, Optional, Tuple
 
+import k2
 import numpy as np
 import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
+from beam_search import Hypothesis, HypothesisList, _get_hyps_shape
 from emformer import LOG_EPSILON, stack_states, unstack_states
-from streaming_feature_extractor import FeatureExtractionStream
+from streaming_feature_extractor import (
+    FeatureExtractionStream,
+    GreedySearchStream,
+    ModifiedBeamSearchStream,
+)
 from train import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import (
@@ -50,6 +57,7 @@ def get_parser():
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
+
     parser.add_argument(
         "--avg",
         type=int,
@@ -208,7 +216,7 @@ class StreamList(object):
         self,
         batch_size: int,
         context_size: int,
-        blank_id: int,
+        decoding_method: str,
     ):
         """
         Args:
           batch_size:
             Size of this batch.
           context_size:
             Context size of the RNN-T decoder model.
-          blank_id:
-            The ID of the blank symbol of the BPE model.
+          decoding_method:
+            Decoding method. The possible values are:
+              - greedy_search
+              - modified_beam_search
         """
+        decoding_classes = {
+            "greedy_search": GreedySearchStream,
+            "modified_beam_search": ModifiedBeamSearchStream,
+        }
+
+        assert decoding_method in decoding_classes
+        cls = decoding_classes[decoding_method]
+
         self.streams = [
-            FeatureExtractionStream(
-                context_size=context_size, blank_id=blank_id
-            )
-            for _ in range(batch_size)
+            cls(context_size=context_size) for _ in range(batch_size)
         ]
 
     @property
@@ -238,7 +253,7 @@ class StreamList(object):
         audio_samples: List[torch.Tensor],
         sampling_rate: float,
     ):
-        """Feeed audio samples to each stream.
+        """Feed audio samples to each stream.
 
         Args:
           audio_samples:
             A list of 1-D tensors containing the audio samples for each
@@ -314,7 +329,7 @@ class StreamList(object):
 
 def greedy_search(
     model: nn.Module,
-    streams: List[FeatureExtractionStream],
+    streams: List[GreedySearchStream],
     encoder_out: torch.Tensor,
     sp: spm.SentencePieceProcessor,
 ):
@@ -333,7 +348,15 @@ def greedy_search(
     blank_id = model.decoder.blank_id
     context_size = model.decoder.context_size
     device = model.device
 
     assert len(streams) == encoder_out.size(0)
+    assert encoder_out.ndim == 3
+
+    for s in streams:
+        if s.hyp is None:
+            s.hyp = Hypothesis(
+                ys=([blank_id] * context_size),
+                log_prob=torch.tensor([0.0], device=device),
+            )
     if streams[0].decoder_out is None:
         decoder_input = torch.tensor(
             [stream.hyp.ys[-context_size:] for stream in streams],
@@ -351,8 +374,6 @@ def greedy_search(
             dim=0,
         )
 
-    assert encoder_out.ndim == 3
-
     T = encoder_out.size(1)
     for t in range(T):
         current_encoder_out = encoder_out[:, t]
@@ -381,20 +402,132 @@ def greedy_search(
         )
 
     for k, s in enumerate(streams):
-        logging.info(
-            f"Partial result {k}:\n{sp.decode(s.hyp.ys[context_size:])}"
-        )
+        logging.info(f"Partial result {k}:\n{sp.decode(s.result)}")
 
     decoder_out_list = decoder_out.unbind(dim=0)
-
     for i, d in enumerate(decoder_out_list):
         streams[i].decoder_out = d
 
 
+def modified_beam_search(
+    model: nn.Module,
+    streams: List[ModifiedBeamSearchStream],
+    encoder_out: torch.Tensor,
+    sp: spm.SentencePieceProcessor,
+    beam: int = 4,
+):
+    """
+    Args:
+      model:
+        The RNN-T model.
+      streams:
+        A list of stream objects, one per utterance in the batch.
+      encoder_out:
+        A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
+        the encoder model.
+      sp:
+        The BPE model.
+      beam:
+        Number of active paths during the beam search.
+    """
+    blank_id = model.decoder.blank_id
+    context_size = model.decoder.context_size
+    device = model.device
+    assert encoder_out.ndim == 3, encoder_out.shape
+    assert len(streams) == encoder_out.size(0)
+    batch_size = len(streams)
+
+    for s in streams:
+        if len(s.hyps) == 0:
+            s.hyps.add(
+                Hypothesis(
+                    ys=[blank_id] * context_size,
+                    log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+                )
+            )
+
+    B = [s.hyps for s in streams]
+
+    T = encoder_out.size(1)
+    for t in range(T):
+        current_encoder_out = encoder_out[:, t]
+        # current_encoder_out's shape: (batch_size, encoder_out_dim)
+
+        hyps_shape = _get_hyps_shape(B).to(device)
+
+        A = [list(b) for b in B]
+        B = [HypothesisList() for _ in range(batch_size)]
+
+        ys_log_probs = torch.cat(
+            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
+        )  # (num_hyps, 1)
+
+        decoder_input = torch.tensor(
+            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
+            device=device,
+            dtype=torch.int64,
+        )  # (num_hyps, context_size)
+
+        decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1)
+        # decoder_out is of shape (num_hyps, decoder_output_dim)
+
+        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
+        # as index, so we use `to(torch.int64)` below.
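+        #
+        # `hyps_shape` has two axes, [stream][hypothesis]; hyps_shape.row_ids(1)
+        # maps every flattened hypothesis back to the index of the stream it
+        # belongs to. The index_select below therefore repeats each stream's
+        # single encoder frame once per active hypothesis of that stream, so
+        # that current_encoder_out lines up row by row with decoder_out.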
+        current_encoder_out = torch.index_select(
+            current_encoder_out,
+            dim=0,
+            index=hyps_shape.row_ids(1).to(torch.int64),
+        )  # (num_hyps, encoder_out_dim)
+
+        logits = model.joiner(current_encoder_out, decoder_out)
+        # logits is of shape (num_hyps, vocab_size)
+
+        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)
+
+        log_probs.add_(ys_log_probs)
+
+        vocab_size = log_probs.size(-1)
+
+        log_probs = log_probs.reshape(-1)
+
+        row_splits = hyps_shape.row_splits(1) * vocab_size
+        log_probs_shape = k2.ragged.create_ragged_shape2(
+            row_splits=row_splits, cached_tot_size=log_probs.numel()
+        )
+        ragged_log_probs = k2.RaggedTensor(
+            shape=log_probs_shape, value=log_probs
+        )
+
+        for i in range(batch_size):
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
+
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
+                topk_token_indexes = (topk_indexes % vocab_size).tolist()
+
+            for k in range(len(topk_hyp_indexes)):
+                hyp_idx = topk_hyp_indexes[k]
+                hyp = A[i][hyp_idx]
+
+                new_ys = hyp.ys[:]
+                new_token = topk_token_indexes[k]
+                if new_token != blank_id:
+                    new_ys.append(new_token)
+
+                new_log_prob = topk_log_probs[k]
+                new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
+                B[i].add(new_hyp)
+
+            streams[i].hyps = B[i]
+            logging.info(f"Partial result {i}:\n{sp.decode(streams[i].result)}")
+
+
 def process_features(
     model: nn.Module,
     features: torch.Tensor,
     streams: List[FeatureExtractionStream],
+    params: AttributeDict,
     sp: spm.SentencePieceProcessor,
 ) -> None:
     """Process features for each stream in parallel.
@@ -406,6 +539,8 @@ def process_features(
         A 3-D tensor of shape (N, T, C).
       streams:
         A list of streams of size (N,).
+      params:
+        It is the return value of :func:`get_params`.
       sp:
         The BPE model.
""" @@ -439,12 +574,25 @@ def process_features( for i, s in enumerate(state_list): streams[i].states = s - greedy_search( - model=model, - streams=streams, - encoder_out=encoder_out, - sp=sp, - ) + if params.decoding_method == "greedy_search": + greedy_search( + model=model, + streams=streams, + encoder_out=encoder_out, + sp=sp, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=streams, + encoder_out=encoder_out, + sp=sp, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) def decode_batch( @@ -479,7 +627,7 @@ def decode_batch( stream_list = StreamList( batch_size=batch_size, context_size=params.context_size, - blank_id=params.blank_id, + decoding_method=params.decoding_method, ) while not streaming_audio_samples.done: @@ -497,11 +645,12 @@ def decode_batch( model=model, features=features, streams=active_streams, + params=params, sp=sp, ) results = [] for s in stream_list.streams: - text = sp.decode(s.hyp.ys[params.context_size :]) + text = sp.decode(s.result) results.append(text) return results diff --git a/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py b/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py index b20f6502f..a040cc09c 100644 --- a/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py +++ b/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py @@ -17,7 +17,7 @@ from typing import List, Optional import torch -from beam_search import Hypothesis +from beam_search import Hypothesis, HypothesisList from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature @@ -41,14 +41,10 @@ def _create_streaming_feature_extractor() -> OnlineFeature: class FeatureExtractionStream(object): - def __init__(self, context_size: int, blank_id: int = 0) -> None: - """Context size of the RNN-T decoder model.""" + def __init__( + self, + ) -> None: self.feature_extractor = _create_streaming_feature_extractor() - self.hyp = Hypothesis( - ys=([blank_id] * context_size), - log_prob=torch.tensor([0.0]), - ) # for greedy search, will extend it to beam search - # It contains a list of 1-D tensors representing the feature frames. self.feature_frames: List[torch.Tensor] = [] @@ -58,11 +54,6 @@ class FeatureExtractionStream(object): # encoder layer. self.states: Optional[List[List[torch.Tensor]]] = None - # For the RNN-T decoder, it contains the decoder output - # corresponding to the decoder input self.hyp.ys[-context_size:] - # Its shape is (decoder_out_dim,) - self.decoder_out: Optional[torch.Tensor] = None - # After calling `self.input_finished()`, we set this flag to True self._done = False @@ -85,9 +76,9 @@ class FeatureExtractionStream(object): check to ensure that the input sampling rate equals to the one used in the extractor. If they are not equal, then no resampling will be performed; instead an error will be thrown. - waveform: - A 1-D torch tensor of dtype torch.float32 containing audio samples. - It should be on CPU. + waveform: + A 1-D torch tensor of dtype torch.float32 containing audio samples. + It should be on CPU. 
""" self.feature_extractor.accept_waveform( sampling_rate=sampling_rate, @@ -114,3 +105,33 @@ class FeatureExtractionStream(object): frame = self.feature_extractor.get_frame(self.num_fetched_frames) self.feature_frames.append(frame) self.num_fetched_frames += 1 + + +class GreedySearchStream(FeatureExtractionStream): + def __init__(self, context_size: int) -> None: + """FeatureExtractionStream class for greedy search.""" + super().__init__() + self.context_size = context_size + # For the RNN-T decoder, it contains the decoder output + # corresponding to the decoder input self.hyp.ys[-context_size:] + # Its shape is (decoder_out_dim,) + self.hyp: Hypothesis = None + self.decoder_out: Optional[torch.Tensor] = None + + @property + def result(self) -> List[int]: + return self.hyp.ys[self.context_size :] + + +class ModifiedBeamSearchStream(FeatureExtractionStream): + def __init__(self, context_size: int) -> None: + """FeatureExtractionStream class for modified beam search decoding.""" + super().__init__() + self.context_size = context_size + self.hyps = HypothesisList() + self.best_hyp = None + + @property + def result(self) -> List[int]: + best_hyp = self.hyps.get_most_probable(length_norm=True) + return best_hyp.ys[self.context_size :] From e74654c2a242677bcdc1481b1864515ed52a6f52 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Tue, 19 Apr 2022 22:05:14 +0800 Subject: [PATCH 2/2] Formatted imports. --- egs/librispeech/ASR/transducer_emformer/streaming_decode.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py index f5e24a0d9..df3303100 100755 --- a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py +++ b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py @@ -37,11 +37,7 @@ from streaming_feature_extractor import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import AttributeDict, setup_logger