From cf0ce8db322e48b2148c1e6aee59801391b620a8 Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Thu, 21 Apr 2022 19:48:35 +0800
Subject: [PATCH] Fixed streaming decoding code for the Emformer model.

---
 .../beam_search.py                            |   4 +-
 .../transducer_emformer/streaming_decode.py   | 105 ++++++++----------
 .../streaming_feature_extractor.py            |  63 ++++-------
 3 files changed, 74 insertions(+), 98 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py
index 2cb7a8cba..574c637ae 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py
@@ -367,7 +367,7 @@ class HypothesisList(object):
         return ", ".join(s)
 
 
-def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
+def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
     """Return a ragged shape with axes [utt][num_hyps].
 
     Args:
@@ -431,7 +431,7 @@ def modified_beam_search(
         current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
         # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
 
-        hyps_shape = _get_hyps_shape(B).to(device)
+        hyps_shape = get_hyps_shape(B).to(device)
 
         A = [list(b) for b in B]
         B = [HypothesisList() for _ in range(batch_size)]
diff --git a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py
index df3303100..c5bcb3aee 100755
--- a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py
+++ b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py
@@ -28,16 +28,16 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from beam_search import Hypothesis, HypothesisList, _get_hyps_shape
+from beam_search import Hypothesis, HypothesisList, get_hyps_shape
 from emformer import LOG_EPSILON, stack_states, unstack_states
-from streaming_feature_extractor import (
-    FeatureExtractionStream,
-    GreedySearchStream,
-    ModifiedBeamSearchStream,
-)
+from streaming_feature_extractor import FeatureExtractionStream
 from train import add_model_arguments, get_params, get_transducer_model
 
-from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.utils import AttributeDict, setup_logger
 
 
@@ -225,16 +225,12 @@ class StreamList(object):
           - greedy_search
           - modified_beam_search
         """
-        decoding_classes = {
-            "greedy_search": GreedySearchStream,
-            "modified_beam_search": ModifiedBeamSearchStream,
-        }
-
-        assert decoding_method in decoding_classes
-        cls = decoding_classes[decoding_method]
         self.streams = [
-            cls(context_size=context_size) for _ in range(batch_size)
+            FeatureExtractionStream(
+                context_size=context_size, decoding_method=decoding_method
+            )
+            for _ in range(batch_size)
         ]
 
     @property
@@ -325,7 +321,7 @@
 
 def greedy_search(
     model: nn.Module,
-    streams: List[GreedySearchStream],
+    streams: List[FeatureExtractionStream],
     encoder_out: torch.Tensor,
     sp: spm.SentencePieceProcessor,
 ):
@@ -333,36 +329,31 @@ def greedy_search(
     Args:
       model:
        The RNN-T model.
-      stream:
-        A stream object.
+      streams:
+        A list of FeatureExtractionStream objects.
      encoder_out:
         A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
         the encoder model.
       sp:
         The BPE model.
""" - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = model.device assert len(streams) == encoder_out.size(0) assert encoder_out.ndim == 3 - for s in streams: - if s.hyp is None: - s.hyp = Hypothesis( - ys=([blank_id] * context_size), - log_prob=torch.tensor([0.0], device=device), - ) + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device + T = encoder_out.size(1) + if streams[0].decoder_out is None: + for stream in streams: + stream.hyp = [blank_id] * context_size decoder_input = torch.tensor( - [stream.hyp.ys[-context_size:] for stream in streams], + [stream.hyp[-context_size:] for stream in streams], device=device, dtype=torch.int64, ) - decoder_out = model.decoder( - decoder_input, - need_pad=False, - ).squeeze(1) + decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1) # decoder_out is of shape (N, decoder_out_dim) else: decoder_out = torch.stack( @@ -370,7 +361,6 @@ def greedy_search( dim=0, ) - T = encoder_out.size(1) for t in range(T): current_encoder_out = encoder_out[:, t] # current_encoder_out's shape: (batch_size, encoder_out_dim) @@ -383,22 +373,23 @@ def greedy_search( emitted = False for i, v in enumerate(y): if v != blank_id: - streams[i].hyp.ys.append(v) + streams[i].hyp.append(v) emitted = True - if emitted: # update decoder output decoder_input = torch.tensor( - [stream.hyp.ys[-context_size:] for stream in streams], + [stream.hyp[-context_size:] for stream in streams], device=device, dtype=torch.int64, ) - decoder_out = model.decoder(decoder_input, need_pad=False).squeeze( - 1 - ) + decoder_out = model.decoder( + decoder_input, + need_pad=False, + ).squeeze(1) - for k, s in enumerate(streams): - logging.info(f"Partial result {k}:\n{sp.decode(s.result)}") + for k, stream in enumerate(streams): + result = sp.decode(stream.decoding_result()) + logging.info(f"Partial result {k}:\n{result}") decoder_out_list = decoder_out.unbind(dim=0) for i, d in enumerate(decoder_out_list): @@ -407,7 +398,7 @@ def greedy_search( def modified_beam_search( model: nn.Module, - streams: List[ModifiedBeamSearchStream], + streams: List[FeatureExtractionStream], encoder_out: torch.Tensor, sp: spm.SentencePieceProcessor, beam: int = 4, @@ -426,36 +417,35 @@ def modified_beam_search( beam: Number of active paths during the beam search. 
""" + assert encoder_out.ndim == 3, encoder_out.shape + assert len(streams) == encoder_out.size(0) + blank_id = model.decoder.blank_id context_size = model.decoder.context_size device = model.device - assert encoder_out.ndim == 3, encoder_out.shape - assert len(streams) == encoder_out.size(0) batch_size = len(streams) + T = encoder_out.size(1) - for s in streams: - if len(s.hyps) == 0: - s.hyps.add( + for stream in streams: + if len(stream.hyps) == 0: + stream.hyps.add( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), ) ) - - B = [s.hyps for s in streams] - - T = encoder_out.size(1) + B = [stream.hyps for stream in streams] for t in range(T): current_encoder_out = encoder_out[:, t] # current_encoder_out's shape: (batch_size, encoder_out_dim) - hyps_shape = _get_hyps_shape(B).to(device) + hyps_shape = get_hyps_shape(B).to(device) A = [list(b) for b in B] B = [HypothesisList() for _ in range(batch_size)] - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ys_log_probs = torch.stack( + [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0 ) # (num_hyps, 1) decoder_input = torch.tensor( @@ -516,7 +506,8 @@ def modified_beam_search( B[i].add(new_hyp) streams[i].hyps = B[i] - logging.info(f"Partial result {i}:\n{sp.decode(streams[i].result)}") + result = sp.decode(streams[i].decoding_result()) + logging.info(f"Partial result {i}:\n{result}") def process_features( @@ -645,8 +636,8 @@ def decode_batch( sp=sp, ) results = [] - for s in stream_list.streams: - text = sp.decode(s.result) + for stream in stream_list.streams: + text = sp.decode(stream.decoding_result()) results.append(text) return results diff --git a/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py b/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py index a040cc09c..c3d9a5675 100644 --- a/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py +++ b/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py @@ -14,11 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional - -import torch -from beam_search import Hypothesis, HypothesisList +from beam_search import HypothesisList from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature +from typing import List, Optional +import torch def _create_streaming_feature_extractor() -> OnlineFeature: @@ -41,21 +40,28 @@ def _create_streaming_feature_extractor() -> OnlineFeature: class FeatureExtractionStream(object): - def __init__( - self, - ) -> None: + def __init__(self, context_size: int, decoding_method: str) -> None: self.feature_extractor = _create_streaming_feature_extractor() # It contains a list of 1-D tensors representing the feature frames. self.feature_frames: List[torch.Tensor] = [] - self.num_fetched_frames = 0 + # After calling `self.input_finished()`, we set this flag to True + self._done = False # For the emformer model, it contains the states of each # encoder layer. self.states: Optional[List[List[torch.Tensor]]] = None - # After calling `self.input_finished()`, we set this flag to True - self._done = False + # It use different attributes for different decoding methods. 
+        self.context_size = context_size
+        self.decoding_method = decoding_method
+        if decoding_method == "greedy_search":
+            self.hyp: Optional[List[int]] = None
+            self.decoder_out: Optional[torch.Tensor] = None
+        elif decoding_method == "modified_beam_search":
+            self.hyps = HypothesisList()
+        else:
+            raise ValueError(f"Unsupported decoding method: {decoding_method}")
 
     def accept_waveform(
         self,
@@ -106,32 +113,11 @@
             self.feature_frames.append(frame)
             self.num_fetched_frames += 1
 
-
-class GreedySearchStream(FeatureExtractionStream):
-    def __init__(self, context_size: int) -> None:
-        """FeatureExtractionStream class for greedy search."""
-        super().__init__()
-        self.context_size = context_size
-        # For the RNN-T decoder, it contains the decoder output
-        # corresponding to the decoder input self.hyp.ys[-context_size:]
-        # Its shape is (decoder_out_dim,)
-        self.hyp: Hypothesis = None
-        self.decoder_out: Optional[torch.Tensor] = None
-
-    @property
-    def result(self) -> List[int]:
-        return self.hyp.ys[self.context_size :]
-
-
-class ModifiedBeamSearchStream(FeatureExtractionStream):
-    def __init__(self, context_size: int) -> None:
-        """FeatureExtractionStream class for modified beam search decoding."""
-        super().__init__()
-        self.context_size = context_size
-        self.hyps = HypothesisList()
-        self.best_hyp = None
-
-    @property
-    def result(self) -> List[int]:
-        best_hyp = self.hyps.get_most_probable(length_norm=True)
-        return best_hyp.ys[self.context_size :]
+    def decoding_result(self) -> List[int]:
+        """Obtain the current decoding result."""
+        if self.decoding_method == "greedy_search":
+            return self.hyp[self.context_size :]
+        else:
+            assert self.decoding_method == "modified_beam_search"
+            best_hyp = self.hyps.get_most_probable(length_norm=True)
+            return best_hyp.ys[self.context_size :]
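
Usage sketch for reviewers (illustrative, not part of the patch): after this
change a single FeatureExtractionStream serves both decoding methods, selected
by the decoding_method argument rather than by subclass, and both methods read
partial results through the same decoding_result() accessor. The snippet below
assumes the patched transducer_emformer files are on the import path and that
kaldifeat is installed; the blank id, token ids, chunk length, and 16 kHz rate
are made-up values, and accept_waveform is assumed to follow kaldifeat's
(sampling_rate, waveform) convention.

    import torch

    from streaming_feature_extractor import FeatureExtractionStream

    # One stream per utterance; decoding_method selects which attributes
    # are used for decoding (hyp/decoder_out vs. hyps), not which class.
    stream = FeatureExtractionStream(
        context_size=2, decoding_method="greedy_search"
    )

    # Feed audio chunk by chunk; feature frames accumulate in
    # stream.feature_frames until a decoding function consumes them.
    stream.accept_waveform(sampling_rate=16000, waveform=torch.zeros(1600))
    stream.input_finished()

    # greedy_search() seeds stream.hyp with `context_size` blanks and then
    # appends emitted token ids; decoding_result() strips the blank context.
    blank_id = 0  # made-up value; streaming_decode.py reads it from the model
    stream.hyp = [blank_id] * 2 + [25, 871]  # as if two tokens were emitted
    assert stream.decoding_result() == [25, 871]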