From b1d0956855d5a113a1f3b25eadf7bd60f8a61935 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Mon, 25 Jul 2022 16:53:23 +0800 Subject: [PATCH] Add modified_beam_search for streaming decode (#489) * Add modified_beam_search for pruned_transducer_stateless/streaming_decode.py * refactor * modified beam search for stateless3,4 * Fix comments * Add real streaming ci --- ...pruned-transducer-stateless2-2022-06-26.sh | 16 +- .../beam_search.py | 4 +- .../decode_stream.py | 29 +- .../streaming_beam_search.py | 280 +++++++++++++++++ .../streaming_decode.py | 202 ++++-------- .../streaming_beam_search.py | 288 ++++++++++++++++++ .../streaming_decode.py | 202 ++++-------- .../streaming_beam_search.py | 1 + .../streaming_decode.py | 206 ++++--------- .../streaming_beam_search.py | 1 + .../streaming_decode.py | 206 ++++--------- 11 files changed, 843 insertions(+), 592 deletions(-) create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless3/streaming_beam_search.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless4/streaming_beam_search.py diff --git a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh index 85bbb919f..d9dc34e48 100755 --- a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh +++ b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh @@ -70,7 +70,7 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == max_duration=100 for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" + log "Simulate streaming decoding with $method" ./pruned_transducer_stateless2/decode.py \ --decoding-method $method \ @@ -82,5
+82,19 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == --causal-convolution 1 done + for method in greedy_search fast_beam_search modified_beam_search; do + log "Real streaming decoding with $method" + + ./pruned_transducer_stateless2/streaming_decode.py \ + --decoding-method $method \ + --epoch 999 \ + --avg 1 \ + --num-decode-streams 100 \ + --exp-dir pruned_transducer_stateless2/exp \ + --left-context 32 \ + --decode-chunk-size 8 \ + --right-context 0 + done + rm pruned_transducer_stateless2/exp/*.pt fi diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index 40c442e7a..7af9cc3d7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -751,7 +751,7 @@ class HypothesisList(object): return ", ".join(s) -def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: +def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: """Return a ragged shape with axes [utt][num_hyps]. Args: @@ -847,7 +847,7 @@ def modified_beam_search( finalized_B = B[batch_size:] + finalized_B B = B[:batch_size] - hyps_shape = _get_hyps_shape(B).to(device) + hyps_shape = get_hyps_shape(B).to(device) A = [list(b) for b in B] B = [HypothesisList() for _ in range(batch_size)] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py index ba5e80555..6c0e9ba19 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py @@ -19,6 +19,7 @@ from typing import List, Optional, Tuple import k2 import torch +from beam_search import Hypothesis, HypothesisList from icefall.utils import AttributeDict @@ -42,7 +43,8 @@ class DecodeStream(object): device: The device to run this stream. 
""" - if decoding_graph is not None: + if params.decoding_method == "fast_beam_search": + assert decoding_graph is not None assert device == decoding_graph.device self.params = params @@ -77,15 +79,23 @@ class DecodeStream(object): if params.decoding_method == "greedy_search": self.hyp = [params.blank_id] * params.context_size + elif params.decoding_method == "modified_beam_search": + self.hyps = HypothesisList() + self.hyps.add( + Hypothesis( + ys=[params.blank_id] * params.context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) elif params.decoding_method == "fast_beam_search": # The rnnt_decoding_stream for fast_beam_search. self.rnnt_decoding_stream: k2.RnntDecodingStream = ( k2.RnntDecodingStream(decoding_graph) ) else: - assert ( - False - ), f"Decoding method :{params.decoding_method} do not support." + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) @property def done(self) -> bool: @@ -124,3 +134,14 @@ class DecodeStream(object): self._done = True return ret_features, ret_length + + def decoding_result(self) -> List[int]: + """Obtain current decoding result.""" + if self.params.decoding_method == "greedy_search": + return self.hyp[self.params.context_size :] # noqa + elif self.params.decoding_method == "modified_beam_search": + best_hyp = self.hyps.get_most_probable(length_norm=True) + return best_hyp.ys[self.params.context_size :] # noqa + else: + assert self.params.decoding_method == "fast_beam_search" + return self.hyp diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py new file mode 100644 index 000000000..dcf6dc42f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py @@ -0,0 +1,280 @@ +# Copyright 2022 Xiaomi Corp. 
(authors: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import List + +import k2 +import torch +import torch.nn as nn +from beam_search import Hypothesis, HypothesisList, get_hyps_shape +from decode_stream import DecodeStream + +from icefall.decode import one_best_decoding +from icefall.utils import get_texts + + +def greedy_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[DecodeStream], +) -> None: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + streams: + A list of Stream objects. 
+ """ + assert len(streams) == encoder_out.size(0) + assert encoder_out.ndim == 3 + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device + T = encoder_out.size(1) + + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + # decoder_out is of shape (N, 1, decoder_out_dim) + decoder_out = model.decoder(decoder_input, need_pad=False) + + for t in range(T): + # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) + current_encoder_out = encoder_out[:, t : t + 1, :] # noqa + + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + ) + # logits'shape (batch_size, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + streams[i].hyp.append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=False, + ) + + +def modified_beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[DecodeStream], + num_active_paths: int = 4, +) -> None: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The RNN-T model. + encoder_out: + A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of + the encoder model. + streams: + A list of stream objects. + num_active_paths: + Number of active paths during the beam search. 
+ """ + assert encoder_out.ndim == 3, encoder_out.shape + assert len(streams) == encoder_out.size(0) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = next(model.parameters()).device + batch_size = len(streams) + T = encoder_out.size(1) + + B = [stream.hyps for stream in streams] + + for t in range(T): + current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.stack( + [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0 + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. 
+ current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, encoder_out_dim) + + logits = model.joiner(current_encoder_out, decoder_out) + # logits is of shape (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk( + num_active_paths + ) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token != blank_id: + new_ys.append(new_token) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B[i].add(new_hyp) + + for i in range(batch_size): + streams[i].hyps = B[i] + + +def fast_beam_search_one_best( + model: nn.Module, + encoder_out: torch.Tensor, + processed_lens: torch.Tensor, + streams: List[DecodeStream], + beam: float, + max_states: int, + max_contexts: int, +) -> None: + """It limits the maximum number of symbols per frame to 1. + + A lattice is first generated by Fsa-based beam search, then we get the + recognition by applying shortest path on the lattice. + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. 
+ processed_lens: + A tensor of shape (N,) containing the number of processed frames + in `encoder_out` before padding. + streams: + A list of stream objects. + beam: + Beam value, similar to the beam used in Kaldi. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts per stream per frame. + """ + assert encoder_out.ndim == 3 + B, T, C = encoder_out.shape + assert B == len(streams) + + context_size = model.decoder.context_size + vocab_size = model.decoder.vocab_size + + config = k2.RnntDecodingConfig( + vocab_size=vocab_size, + decoder_history_len=context_size, + beam=beam, + max_contexts=max_contexts, + max_states=max_states, + ) + individual_streams = [] + for i in range(B): + individual_streams.append(streams[i].rnnt_decoding_stream) + decoding_streams = k2.RnntDecodingStreams(individual_streams, config) + + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + ) + logits = logits.squeeze(1).squeeze(1) + log_probs = logits.log_softmax(dim=-1) + decoding_streams.advance(log_probs) + + decoding_streams.terminate_and_flush_to_streams() + + lattice = decoding_streams.format_output(processed_lens.tolist()) + best_path = one_best_decoding(lattice) + hyp_tokens = get_texts(best_path) + + for i in range(B): + streams[i].hyp = hyp_tokens[i] diff --git
a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py index f05cf7a91..e455627f3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py @@ -17,13 +17,13 @@ """ Usage: -./pruned_transducer_stateless2/streaming_decode.py \ +./pruned_transducer_stateless/streaming_decode.py \ --epoch 28 \ --avg 15 \ --decode-chunk-size 8 \ --left-context 32 \ --right-context 0 \ - --exp-dir ./pruned_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless/exp \ --decoding_method greedy_search \ --num-decode-streams 1000 """ @@ -43,6 +43,11 @@ from asr_datamodule import LibriSpeechAsrDataModule from decode_stream import DecodeStream from kaldifeat import Fbank, FbankOptions from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model @@ -51,10 +56,8 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.decode import one_best_decoding from icefall.utils import ( AttributeDict, - get_texts, setup_logger, store_transcripts, write_error_stats, @@ -114,10 +117,21 @@ def get_parser(): "--decoding-method", type=str, default="greedy_search", - help="""Support only greedy_search and fast_beam_search now. + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search """, ) + parser.add_argument( + "--num-active-paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. 
Used only when --decoding-method is modified_beam_search.""", + ) + parser.add_argument( "--beam", type=float, @@ -185,103 +199,6 @@ def get_parser(): return parser -def greedy_search( - model: nn.Module, - encoder_out: torch.Tensor, - streams: List[DecodeStream], -) -> List[List[int]]: - - assert len(streams) == encoder_out.size(0) - assert encoder_out.ndim == 3 - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = model.device - T = encoder_out.size(1) - - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - # decoder_out is of shape (N, decoder_out_dim) - decoder_out = model.decoder(decoder_input, need_pad=False) - - for t in range(T): - # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) - current_encoder_out = encoder_out[:, t : t + 1, :] # noqa - - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - ) - # logits'shape (batch_size, vocab_size) - logits = logits.squeeze(1).squeeze(1) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - streams[i].hyp.append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder( - decoder_input, - need_pad=False, - ) - - hyp_tokens = [] - for stream in streams: - hyp_tokens.append(stream.hyp) - return hyp_tokens - - -def fast_beam_search( - model: nn.Module, - encoder_out: torch.Tensor, - processed_lens: torch.Tensor, - decoding_streams: k2.RnntDecodingStreams, -) -> List[List[int]]: - - B, T, C = encoder_out.shape - for t in range(T): - # shape is a RaggedShape of shape (B, context) - # contexts is a Tensor of shape (shape.NumElements(), context_size) - shape, contexts = decoding_streams.get_contexts() - # 
`nn.Embedding()` in torch below v1.7.1 supports only torch.int64 - contexts = contexts.to(torch.int64) - # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) - decoder_out = model.decoder(contexts, need_pad=False) - # current_encoder_out is of shape - # (shape.NumElements(), 1, joiner_dim) - # fmt: off - current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) - ) - # fmt: on - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - ) - logits = logits.squeeze(1).squeeze(1) - log_probs = logits.log_softmax(dim=-1) - decoding_streams.advance(log_probs) - - decoding_streams.terminate_and_flush_to_streams() - - lattice = decoding_streams.format_output(processed_lens.tolist()) - best_path = one_best_decoding(lattice) - hyp_tokens = get_texts(best_path) - return hyp_tokens - - def decode_one_chunk( params: AttributeDict, model: nn.Module, @@ -305,8 +222,6 @@ def decode_one_chunk( features = [] feature_lens = [] states = [] - - rnnt_stream_list = [] processed_lens = [] for stream in decode_streams: @@ -317,8 +232,6 @@ def decode_one_chunk( feature_lens.append(feat_len) states.append(stream.states) processed_lens.append(stream.done_frames) - if params.decoding_method == "fast_beam_search": - rnnt_stream_list.append(stream.rnnt_decoding_stream) feature_lens = torch.tensor(feature_lens, device=device) features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) @@ -330,19 +243,13 @@ def decode_one_chunk( # frames. 
tail_length = 7 + (2 + params.right_context) * params.subsampling_factor if features.size(1) < tail_length: - feature_lens += tail_length - features.size(1) - features = torch.cat( - [ - features, - torch.tensor( - LOG_EPS, dtype=features.dtype, device=device - ).expand( - features.size(0), - tail_length - features.size(1), - features.size(2), - ), - ], - dim=1, + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, ) states = [ @@ -362,22 +269,31 @@ def decode_one_chunk( ) if params.decoding_method == "greedy_search": - hyp_tokens = greedy_search(model, encoder_out, decode_streams) - elif params.decoding_method == "fast_beam_search": - config = k2.RnntDecodingConfig( - vocab_size=params.vocab_size, - decoder_history_len=params.context_size, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams ) - decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config) + elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens - hyp_tokens = fast_beam_search( - model, encoder_out, processed_lens, decoding_streams + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, ) else: - assert False + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -385,8 +301,6 @@ def decode_one_chunk( for i in range(len(decode_streams)): 
decode_streams[i].states = [states[0][i], states[1][i]] decode_streams[i].done_frames += encoder_out_lens[i] - if params.decoding_method == "fast_beam_search": - decode_streams[i].hyp = hyp_tokens[i] if decode_streams[i].done: finished_streams.append(i) @@ -469,13 +383,10 @@ def decode_dataset( params=params, model=model, decode_streams=decode_streams ) for i in sorted(finished_streams, reverse=True): - hyp = decode_streams[i].hyp - if params.decoding_method == "greedy_search": - hyp = hyp[params.context_size :] # noqa decode_results.append( ( decode_streams[i].ground_truth.split(), - sp.decode(hyp).split(), + sp.decode(decode_streams[i].decoding_result()).split(), ) ) del decode_streams[i] @@ -489,24 +400,29 @@ def decode_dataset( params=params, model=model, decode_streams=decode_streams ) for i in sorted(finished_streams, reverse=True): - hyp = decode_streams[i].hyp - if params.decoding_method == "greedy_search": - hyp = hyp[params.context_size :] # noqa decode_results.append( ( decode_streams[i].ground_truth.split(), - sp.decode(hyp).split(), + sp.decode(decode_streams[i].decoding_result()).split(), ) ) del decode_streams[i] - key = "greedy_search" - if params.decoding_method == "fast_beam_search": + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": key = ( f"beam_{params.beam}_" f"max_contexts_{params.max_contexts}_" f"max_states_{params.max_states}" ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + return {key: decode_results} diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py new file mode 100644 index 000000000..9bcd2f9f9 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py @@ -0,0 
+1,288 @@ +# Copyright 2022 Xiaomi Corp. (authors: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import List + +import k2 +import torch +import torch.nn as nn +from beam_search import Hypothesis, HypothesisList, get_hyps_shape +from decode_stream import DecodeStream + +from icefall.decode import one_best_decoding +from icefall.utils import get_texts + + +def greedy_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[DecodeStream], +) -> None: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + streams: + A list of Stream objects. 
+ """ + assert len(streams) == encoder_out.size(0) + assert encoder_out.ndim == 3 + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device + T = encoder_out.size(1) + + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + # decoder_out is of shape (N, 1, decoder_out_dim) + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + for t in range(T): + # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) + current_encoder_out = encoder_out[:, t : t + 1, :] # noqa + + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + # logits'shape (batch_size, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + streams[i].hyp.append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=False, + ) + decoder_out = model.joiner.decoder_proj(decoder_out) + + +def modified_beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[DecodeStream], + num_active_paths: int = 4, +) -> None: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The RNN-T model. + encoder_out: + A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of + the encoder model. + streams: + A list of stream objects. + num_active_paths: + Number of active paths during the beam search. 
+ """ + assert encoder_out.ndim == 3, encoder_out.shape + assert len(streams) == encoder_out.size(0) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = next(model.parameters()).device + batch_size = len(streams) + T = encoder_out.size(1) + + B = [stream.hyps for stream in streams] + + for t in range(T): + current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.stack( + [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0 + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. 
+ current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) + # logits is of shape (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk( + num_active_paths + ) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token != blank_id: + new_ys.append(new_token) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B[i].add(new_hyp) + + for i in range(batch_size): + streams[i].hyps = B[i] + + +def fast_beam_search_one_best( + model: nn.Module, + encoder_out: torch.Tensor, + processed_lens: torch.Tensor, + streams: List[DecodeStream], + beam: float, + max_states: int, + max_contexts: int, +) -> None: + """It limits the maximum number of symbols per frame to 1. + + A lattice is first generated by Fsa-based beam search, then we get the + recognition by applying shortest path on the lattice. + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. 
+ processed_lens: + A tensor of shape (N,) containing the number of processed frames + in `encoder_out` before padding. + streams: + A list of stream objects. + beam: + Beam value, similar to the beam used in Kaldi. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts per stream per frame. + """ + assert encoder_out.ndim == 3 + B, T, C = encoder_out.shape + assert B == len(streams) + + context_size = model.decoder.context_size + vocab_size = model.decoder.vocab_size + + config = k2.RnntDecodingConfig( + vocab_size=vocab_size, + decoder_history_len=context_size, + beam=beam, + max_contexts=max_contexts, + max_states=max_states, + ) + individual_streams = [] + for i in range(B): + individual_streams.append(streams[i].rnnt_decoding_stream) + decoding_streams = k2.RnntDecodingStreams(individual_streams, config) + + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + logits = logits.squeeze(1).squeeze(1) + log_probs = logits.log_softmax(dim=-1) + decoding_streams.advance(log_probs) + + decoding_streams.terminate_and_flush_to_streams() + + lattice = decoding_streams.format_output(processed_lens.tolist()) + best_path = one_best_decoding(lattice) + hyp_tokens = get_texts(best_path) + + for i in
range(B): + streams[i].hyp = hyp_tokens[i] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index b3e1f04c3..79963c968 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -43,6 +43,11 @@ from asr_datamodule import LibriSpeechAsrDataModule from decode_stream import DecodeStream from kaldifeat import Fbank, FbankOptions from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model @@ -51,10 +56,8 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.decode import one_best_decoding from icefall.utils import ( AttributeDict, - get_texts, setup_logger, store_transcripts, write_error_stats, @@ -114,10 +117,21 @@ def get_parser(): "--decoding-method", type=str, default="greedy_search", - help="""Support only greedy_search and fast_beam_search now. + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search """, ) + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. 
Used only when --decoding-method is modified_beam_search.""", + ) + parser.add_argument( "--beam", type=float, @@ -185,109 +199,6 @@ def get_parser(): return parser -def greedy_search( - model: nn.Module, - encoder_out: torch.Tensor, - streams: List[DecodeStream], -) -> List[List[int]]: - - assert len(streams) == encoder_out.size(0) - assert encoder_out.ndim == 3 - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = model.device - T = encoder_out.size(1) - - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - # decoder_out is of shape (N, decoder_out_dim) - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # logging.info(f"decoder_out shape : {decoder_out.shape}") - - for t in range(T): - # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) - current_encoder_out = encoder_out[:, t : t + 1, :] # noqa - - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - # logits'shape (batch_size, vocab_size) - logits = logits.squeeze(1).squeeze(1) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - streams[i].hyp.append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder( - decoder_input, - need_pad=False, - ) - decoder_out = model.joiner.decoder_proj(decoder_out) - - hyp_tokens = [] - for stream in streams: - hyp_tokens.append(stream.hyp) - return hyp_tokens - - -def fast_beam_search( - model: nn.Module, - encoder_out: torch.Tensor, - processed_lens: torch.Tensor, - decoding_streams: k2.RnntDecodingStreams, -) -> List[List[int]]: - - B, T, C = encoder_out.shape - for t in 
range(T): - # shape is a RaggedShape of shape (B, context) - # contexts is a Tensor of shape (shape.NumElements(), context_size) - shape, contexts = decoding_streams.get_contexts() - # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 - contexts = contexts.to(torch.int64) - # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) - decoder_out = model.decoder(contexts, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # current_encoder_out is of shape - # (shape.NumElements(), 1, joiner_dim) - # fmt: off - current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) - ) - # fmt: on - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - logits = logits.squeeze(1).squeeze(1) - log_probs = logits.log_softmax(dim=-1) - decoding_streams.advance(log_probs) - - decoding_streams.terminate_and_flush_to_streams() - - lattice = decoding_streams.format_output(processed_lens.tolist()) - best_path = one_best_decoding(lattice) - hyp_tokens = get_texts(best_path) - return hyp_tokens - - def decode_one_chunk( params: AttributeDict, model: nn.Module, @@ -312,7 +223,6 @@ def decode_one_chunk( feature_lens = [] states = [] - rnnt_stream_list = [] processed_lens = [] for stream in decode_streams: @@ -323,8 +233,6 @@ def decode_one_chunk( feature_lens.append(feat_len) states.append(stream.states) processed_lens.append(stream.done_frames) - if params.decoding_method == "fast_beam_search": - rnnt_stream_list.append(stream.rnnt_decoding_stream) feature_lens = torch.tensor(feature_lens, device=device) features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) @@ -336,19 +244,13 @@ def decode_one_chunk( # frames. 
tail_length = 7 + (2 + params.right_context) * params.subsampling_factor if features.size(1) < tail_length: - feature_lens += tail_length - features.size(1) - features = torch.cat( - [ - features, - torch.tensor( - LOG_EPS, dtype=features.dtype, device=device - ).expand( - features.size(0), - tail_length - features.size(1), - features.size(2), - ), - ], - dim=1, + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, ) states = [ @@ -369,22 +271,31 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - hyp_tokens = greedy_search(model, encoder_out, decode_streams) - elif params.decoding_method == "fast_beam_search": - config = k2.RnntDecodingConfig( - vocab_size=params.vocab_size, - decoder_history_len=params.context_size, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams ) - decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config) + elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens - hyp_tokens = fast_beam_search( - model, encoder_out, processed_lens, decoding_streams + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, ) else: - assert False + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -392,8 +303,6 @@ def 
decode_one_chunk( for i in range(len(decode_streams)): decode_streams[i].states = [states[0][i], states[1][i]] decode_streams[i].done_frames += encoder_out_lens[i] - if params.decoding_method == "fast_beam_search": - decode_streams[i].hyp = hyp_tokens[i] if decode_streams[i].done: finished_streams.append(i) @@ -477,13 +386,10 @@ def decode_dataset( params=params, model=model, decode_streams=decode_streams ) for i in sorted(finished_streams, reverse=True): - hyp = decode_streams[i].hyp - if params.decoding_method == "greedy_search": - hyp = hyp[params.context_size :] # noqa decode_results.append( ( decode_streams[i].ground_truth.split(), - sp.decode(hyp).split(), + sp.decode(decode_streams[i].decoding_result()).split(), ) ) del decode_streams[i] @@ -497,24 +403,28 @@ def decode_dataset( params=params, model=model, decode_streams=decode_streams ) for i in sorted(finished_streams, reverse=True): - hyp = decode_streams[i].hyp - if params.decoding_method == "greedy_search": - hyp = hyp[params.context_size :] # noqa decode_results.append( ( decode_streams[i].ground_truth.split(), - sp.decode(hyp).split(), + sp.decode(decode_streams[i].decoding_result()).split(), ) ) del decode_streams[i] - key = "greedy_search" - if params.decoding_method == "fast_beam_search": + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": key = ( f"beam_{params.beam}_" f"max_contexts_{params.max_contexts}_" f"max_states_{params.max_states}" ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) return {key: decode_results} diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_beam_search.py new file mode 120000 index 000000000..3a5f89833 --- /dev/null +++ 
b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/streaming_beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py index 8af2788be..1976d19a6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py @@ -17,13 +17,13 @@ """ Usage: -./pruned_transducer_stateless2/streaming_decode.py \ +./pruned_transducer_stateless3/streaming_decode.py \ --epoch 28 \ --avg 15 \ --left-context 32 \ --decode-chunk-size 8 \ --right-context 0 \ - --exp-dir ./pruned_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless3/exp \ --decoding_method greedy_search \ --num-decode-streams 1000 """ @@ -44,6 +44,11 @@ from decode_stream import DecodeStream from kaldifeat import Fbank, FbankOptions from lhotse import CutSet from librispeech import LibriSpeech +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model @@ -52,10 +57,8 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.decode import one_best_decoding from icefall.utils import ( AttributeDict, - get_texts, setup_logger, store_transcripts, write_error_stats, @@ -115,10 +118,21 @@ def get_parser(): "--decoding-method", type=str, default="greedy_search", - help="""Support only greedy_search and fast_beam_search now. + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search """, ) + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. 
Used only when --decoding-method is modified_beam_search.""", + ) + parser.add_argument( "--beam", type=float, @@ -186,109 +200,6 @@ def get_parser(): return parser -def greedy_search( - model: nn.Module, - encoder_out: torch.Tensor, - streams: List[DecodeStream], -) -> List[List[int]]: - - assert len(streams) == encoder_out.size(0) - assert encoder_out.ndim == 3 - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = model.device - T = encoder_out.size(1) - - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - # decoder_out is of shape (N, decoder_out_dim) - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # logging.info(f"decoder_out shape : {decoder_out.shape}") - - for t in range(T): - # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) - current_encoder_out = encoder_out[:, t : t + 1, :] # noqa - - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - # logits'shape (batch_size, vocab_size) - logits = logits.squeeze(1).squeeze(1) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - streams[i].hyp.append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder( - decoder_input, - need_pad=False, - ) - decoder_out = model.joiner.decoder_proj(decoder_out) - - hyp_tokens = [] - for stream in streams: - hyp_tokens.append(stream.hyp) - return hyp_tokens - - -def fast_beam_search( - model: nn.Module, - encoder_out: torch.Tensor, - processed_lens: torch.Tensor, - decoding_streams: k2.RnntDecodingStreams, -) -> List[List[int]]: - - B, T, C = encoder_out.shape - for t in 
range(T): - # shape is a RaggedShape of shape (B, context) - # contexts is a Tensor of shape (shape.NumElements(), context_size) - shape, contexts = decoding_streams.get_contexts() - # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 - contexts = contexts.to(torch.int64) - # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) - decoder_out = model.decoder(contexts, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # current_encoder_out is of shape - # (shape.NumElements(), 1, joiner_dim) - # fmt: off - current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) - ) - # fmt: on - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - logits = logits.squeeze(1).squeeze(1) - log_probs = logits.log_softmax(dim=-1) - decoding_streams.advance(log_probs) - - decoding_streams.terminate_and_flush_to_streams() - - lattice = decoding_streams.format_output(processed_lens.tolist()) - best_path = one_best_decoding(lattice) - hyp_tokens = get_texts(best_path) - return hyp_tokens - - def decode_one_chunk( params: AttributeDict, model: nn.Module, @@ -313,7 +224,6 @@ def decode_one_chunk( feature_lens = [] states = [] - rnnt_stream_list = [] processed_lens = [] for stream in decode_streams: @@ -324,8 +234,6 @@ def decode_one_chunk( feature_lens.append(feat_len) states.append(stream.states) processed_lens.append(stream.done_frames) - if params.decoding_method == "fast_beam_search": - rnnt_stream_list.append(stream.rnnt_decoding_stream) feature_lens = torch.tensor(feature_lens, device=device) features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) @@ -337,19 +245,13 @@ def decode_one_chunk( # frames. 
tail_length = 7 + (2 + params.right_context) * params.subsampling_factor if features.size(1) < tail_length: - feature_lens += tail_length - features.size(1) - features = torch.cat( - [ - features, - torch.tensor( - LOG_EPS, dtype=features.dtype, device=device - ).expand( - features.size(0), - tail_length - features.size(1), - features.size(2), - ), - ], - dim=1, + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, ) states = [ @@ -370,22 +272,31 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - hyp_tokens = greedy_search(model, encoder_out, decode_streams) - elif params.decoding_method == "fast_beam_search": - config = k2.RnntDecodingConfig( - vocab_size=params.vocab_size, - decoder_history_len=params.context_size, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams ) - decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config) + elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens - hyp_tokens = fast_beam_search( - model, encoder_out, processed_lens, decoding_streams + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, ) else: - assert False + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -393,8 +304,6 @@ def 
decode_one_chunk( for i in range(len(decode_streams)): decode_streams[i].states = [states[0][i], states[1][i]] decode_streams[i].done_frames += encoder_out_lens[i] - if params.decoding_method == "fast_beam_search": - decode_streams[i].hyp = hyp_tokens[i] if decode_streams[i].done: finished_streams.append(i) @@ -478,13 +387,10 @@ def decode_dataset( params=params, model=model, decode_streams=decode_streams ) for i in sorted(finished_streams, reverse=True): - hyp = decode_streams[i].hyp - if params.decoding_method == "greedy_search": - hyp = hyp[params.context_size :] # noqa decode_results.append( ( decode_streams[i].ground_truth.split(), - sp.decode(hyp).split(), + sp.decode(decode_streams[i].decoding_result()).split(), ) ) del decode_streams[i] @@ -498,24 +404,28 @@ def decode_dataset( params=params, model=model, decode_streams=decode_streams ) for i in sorted(finished_streams, reverse=True): - hyp = decode_streams[i].hyp - if params.decoding_method == "greedy_search": - hyp = hyp[params.context_size :] # noqa decode_results.append( ( decode_streams[i].ground_truth.split(), - sp.decode(hyp).split(), + sp.decode(decode_streams[i].decoding_result()).split(), ) ) del decode_streams[i] - key = "greedy_search" - if params.decoding_method == "fast_beam_search": + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": key = ( f"beam_{params.beam}_" f"max_contexts_{params.max_contexts}_" f"max_states_{params.max_states}" ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) return {key: decode_results} diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_beam_search.py new file mode 120000 index 000000000..3a5f89833 --- /dev/null +++ 
b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/streaming_beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py index 57fd06980..de89d41c2 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py @@ -17,13 +17,13 @@ """ Usage: -./pruned_transducer_stateless2/streaming_decode.py \ +./pruned_transducer_stateless4/streaming_decode.py \ --epoch 28 \ --avg 15 \ --left-context 32 \ --decode-chunk-size 8 \ --right-context 0 \ - --exp-dir ./pruned_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless4/exp \ --decoding_method greedy_search \ --num-decode-streams 200 """ @@ -43,6 +43,11 @@ from asr_datamodule import LibriSpeechAsrDataModule from decode_stream import DecodeStream from kaldifeat import Fbank, FbankOptions from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model @@ -52,10 +57,8 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.decode import one_best_decoding from icefall.utils import ( AttributeDict, - get_texts, setup_logger, store_transcripts, str2bool, @@ -127,10 +130,21 @@ def get_parser(): "--decoding-method", type=str, default="greedy_search", - help="""Support only greedy_search and fast_beam_search now. + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search """, ) + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. 
Used only when --decoding-method is modified_beam_search.""", + ) + parser.add_argument( "--beam", type=float, @@ -198,109 +212,6 @@ def get_parser(): return parser -def greedy_search( - model: nn.Module, - encoder_out: torch.Tensor, - streams: List[DecodeStream], -) -> List[List[int]]: - - assert len(streams) == encoder_out.size(0) - assert encoder_out.ndim == 3 - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = model.device - T = encoder_out.size(1) - - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - # decoder_out is of shape (N, decoder_out_dim) - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # logging.info(f"decoder_out shape : {decoder_out.shape}") - - for t in range(T): - # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) - current_encoder_out = encoder_out[:, t : t + 1, :] # noqa - - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - # logits'shape (batch_size, vocab_size) - logits = logits.squeeze(1).squeeze(1) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - streams[i].hyp.append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder( - decoder_input, - need_pad=False, - ) - decoder_out = model.joiner.decoder_proj(decoder_out) - - hyp_tokens = [] - for stream in streams: - hyp_tokens.append(stream.hyp) - return hyp_tokens - - -def fast_beam_search( - model: nn.Module, - encoder_out: torch.Tensor, - processed_lens: torch.Tensor, - decoding_streams: k2.RnntDecodingStreams, -) -> List[List[int]]: - - B, T, C = encoder_out.shape - for t in 
range(T): - # shape is a RaggedShape of shape (B, context) - # contexts is a Tensor of shape (shape.NumElements(), context_size) - shape, contexts = decoding_streams.get_contexts() - # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 - contexts = contexts.to(torch.int64) - # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) - decoder_out = model.decoder(contexts, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # current_encoder_out is of shape - # (shape.NumElements(), 1, joiner_dim) - # fmt: off - current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) - ) - # fmt: on - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - logits = logits.squeeze(1).squeeze(1) - log_probs = logits.log_softmax(dim=-1) - decoding_streams.advance(log_probs) - - decoding_streams.terminate_and_flush_to_streams() - - lattice = decoding_streams.format_output(processed_lens.tolist()) - best_path = one_best_decoding(lattice) - hyp_tokens = get_texts(best_path) - return hyp_tokens - - def decode_one_chunk( params: AttributeDict, model: nn.Module, @@ -325,7 +236,6 @@ def decode_one_chunk( feature_lens = [] states = [] - rnnt_stream_list = [] processed_lens = [] for stream in decode_streams: @@ -336,8 +246,6 @@ def decode_one_chunk( feature_lens.append(feat_len) states.append(stream.states) processed_lens.append(stream.done_frames) - if params.decoding_method == "fast_beam_search": - rnnt_stream_list.append(stream.rnnt_decoding_stream) feature_lens = torch.tensor(feature_lens, device=device) features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) @@ -349,19 +257,13 @@ def decode_one_chunk( # frames. 
tail_length = 7 + (2 + params.right_context) * params.subsampling_factor if features.size(1) < tail_length: - feature_lens += tail_length - features.size(1) - features = torch.cat( - [ - features, - torch.tensor( - LOG_EPS, dtype=features.dtype, device=device - ).expand( - features.size(0), - tail_length - features.size(1), - features.size(2), - ), - ], - dim=1, + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, ) states = [ @@ -382,22 +284,31 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - hyp_tokens = greedy_search(model, encoder_out, decode_streams) - elif params.decoding_method == "fast_beam_search": - config = k2.RnntDecodingConfig( - vocab_size=params.vocab_size, - decoder_history_len=params.context_size, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams ) - decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config) + elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens - hyp_tokens = fast_beam_search( - model, encoder_out, processed_lens, decoding_streams + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, ) else: - assert False + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -405,8 +316,6 @@ def 
decode_one_chunk( for i in range(len(decode_streams)): decode_streams[i].states = [states[0][i], states[1][i]] decode_streams[i].done_frames += encoder_out_lens[i] - if params.decoding_method == "fast_beam_search": - decode_streams[i].hyp = hyp_tokens[i] if decode_streams[i].done: finished_streams.append(i) @@ -490,13 +399,10 @@ def decode_dataset( params=params, model=model, decode_streams=decode_streams ) for i in sorted(finished_streams, reverse=True): - hyp = decode_streams[i].hyp - if params.decoding_method == "greedy_search": - hyp = hyp[params.context_size :] # noqa decode_results.append( ( decode_streams[i].ground_truth.split(), - sp.decode(hyp).split(), + sp.decode(decode_streams[i].decoding_result()).split(), ) ) del decode_streams[i] @@ -510,24 +416,28 @@ def decode_dataset( params=params, model=model, decode_streams=decode_streams ) for i in sorted(finished_streams, reverse=True): - hyp = decode_streams[i].hyp - if params.decoding_method == "greedy_search": - hyp = hyp[params.context_size :] # noqa decode_results.append( ( decode_streams[i].ground_truth.split(), - sp.decode(hyp).split(), + sp.decode(decode_streams[i].decoding_result()).split(), ) ) del decode_streams[i] - key = "greedy_search" - if params.decoding_method == "fast_beam_search": + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": key = ( f"beam_{params.beam}_" f"max_contexts_{params.max_contexts}_" f"max_states_{params.max_states}" ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) return {key: decode_results}