diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py b/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py index 3a08b100d..3bc81a672 100644 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py @@ -88,7 +88,7 @@ def fast_beam_search( # (shape.NumElements(), 1, encoder_out_dim) # fmt: off current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1) + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).long() ) # fmt: on logits = model.joiner( @@ -486,10 +486,7 @@ def modified_beam_search( for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - topk_hyp_indexes = torch.div( - topk_indexes, vocab_size, rounding_mode="trunc" - ) - topk_hyp_indexes = topk_hyp_indexes.tolist() + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() topk_token_indexes = (topk_indexes % vocab_size).tolist() for k in range(len(topk_hyp_indexes)): diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py index 2c795ede0..0ef774b38 100644 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py @@ -36,7 +36,6 @@ Usage: /path/to/foo.wav \ /path/to/bar.wav - (3) modified beam search ./pruned_transducer_stateless/pretrained.py \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ @@ -46,6 +45,17 @@ Usage: /path/to/foo.wav \ /path/to/bar.wav +(4) fast beam search +./pruned_transducer_stateless/pretrained.py \ + --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 \ + /path/to/foo.wav \ + /path/to/bar.wav + You can also use `./pruned_transducer_stateless/exp/epoch-xx.pt`. Note: ./pruned_transducer_stateless/exp/pretrained.pt is generated by @@ -58,12 +68,19 @@ import logging import math from typing import List +import k2 import kaldifeat import sentencepiece as spm import torch import torch.nn as nn import torchaudio -from beam_search import beam_search, greedy_search, modified_beam_search +from beam_search import ( + beam_search, + fast_beam_search, + greedy_search, + greedy_search_batch, + modified_beam_search, +) from conformer import Conformer from decoder import Decoder from joiner import Joiner @@ -97,12 +114,14 @@ def get_parser(): ) parser.add_argument( - "--method", + "--decoding-method", type=str, default="greedy_search", help="""Possible values are: - greedy_search - beam_search + - modified_beam_search + - fast_beam_search """, ) @@ -123,6 +142,32 @@ def get_parser(): help="Used only when --method is beam_search and modified_beam_search ", ) + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + parser.add_argument( "--context-size", type=int, @@ -134,7 +179,7 @@ def get_parser(): parser.add_argument( "--max-sym-per-frame", type=int, - default=3, + default=1, help="""Maximum number of symbols per frame. Used only when --method is greedy_search. """, @@ -268,6 +313,11 @@ def main(): model.eval() model.device = device + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + logging.info("Constructing Fbank computer") opts = kaldifeat.FbankOptions() opts.device = device @@ -299,34 +349,66 @@ def main(): x=features, x_lens=feature_lengths ) - num_waves = encoder_out.size(0) hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": + msg = f"Using {params.decoding_method}" + if params.decoding_method == "beam_search": msg += f" with beam size {params.beam_size}" logging.info(msg) - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - elif params.method == "modified_beam_search": - hyp = modified_beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) s = "\n" for filename, hyp in zip(params.sound_files, hyps):