diff --git a/egs/ami/ASR/zipformer/decode.py b/egs/ami/ASR/zipformer/decode.py
index 3531d657f..be36d3609 100755
--- a/egs/ami/ASR/zipformer/decode.py
+++ b/egs/ami/ASR/zipformer/decode.py
@@ -106,24 +106,18 @@
 import k2
 import sentencepiece as spm
 import torch
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import AmiAsrDataModule
 from beam_search import (
     beam_search,
-    fast_beam_search_nbest,
     fast_beam_search_nbest_LG,
-    fast_beam_search_nbest_oracle,
     fast_beam_search_one_best,
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
-    modified_beam_search_lm_rescore,
-    modified_beam_search_lm_rescore_LODR,
-    modified_beam_search_lm_shallow_fusion,
-    modified_beam_search_LODR,
 )
 from train import add_model_arguments, get_model, get_params
-from icefall import ContextGraph, LmScorer, NgramLm
+from icefall import LmScorer
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -133,7 +127,6 @@ from icefall.checkpoint import (
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
-    make_pad_mask,
     setup_logger,
     store_transcripts,
     str2bool,
@@ -308,68 +301,6 @@ def get_parser():
         fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
     )

-    parser.add_argument(
-        "--use-shallow-fusion",
-        type=str2bool,
-        default=False,
-        help="""Use neural network LM for shallow fusion.
-        If you want to use LODR, you will also need to set this to true
-        """,
-    )
-
-    parser.add_argument(
-        "--lm-type",
-        type=str,
-        default="rnn",
-        help="Type of NN lm",
-        choices=["rnn", "transformer"],
-    )
-
-    parser.add_argument(
-        "--lm-scale",
-        type=float,
-        default=0.3,
-        help="""The scale of the neural network LM
-        Used only when `--use-shallow-fusion` is set to True.
-        """,
-    )
-
-    parser.add_argument(
-        "--tokens-ngram",
-        type=int,
-        default=2,
-        help="""The order of the ngram lm.
-        """,
-    )
-
-    parser.add_argument(
-        "--backoff-id",
-        type=int,
-        default=500,
-        help="ID of the backoff symbol in the ngram LM",
-    )
-
-    parser.add_argument(
-        "--context-score",
-        type=float,
-        default=2,
-        help="""
-        The bonus score of each token for the context biasing words/phrases.
-        Used only when --decoding-method is modified_beam_search and
-        modified_beam_search_LODR.
-        """,
-    )
-
-    parser.add_argument(
-        "--context-file",
-        type=str,
-        default="",
-        help="""
-        The path of the context biasing lists, one word/phrase each line
-        Used only when --decoding-method is modified_beam_search and
-        modified_beam_search_LODR.
-        """,
-    )
     add_model_arguments(parser)

     return parser
@@ -380,12 +311,8 @@ def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
     batch: dict,
-    word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
-    context_graph: Optional[ContextGraph] = None,
-    LM: Optional[LmScorer] = None,
-    ngram_lm=None,
-    ngram_lm_scale: float = 0.0,
+    word_table: Optional[k2.SymbolTable] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict.

     The dict has the following format:
@@ -474,35 +401,6 @@
         )
         for hyp in hyp_tokens:
             hyps.append([word_table[i] for i in hyp])
-    elif params.decoding_method == "fast_beam_search_nbest":
-        hyp_tokens = fast_beam_search_nbest(
-            model=model,
-            decoding_graph=decoding_graph,
-            encoder_out=encoder_out,
-            encoder_out_lens=encoder_out_lens,
-            beam=params.beam,
-            max_contexts=params.max_contexts,
-            max_states=params.max_states,
-            num_paths=params.num_paths,
-            nbest_scale=params.nbest_scale,
-        )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
-    elif params.decoding_method == "fast_beam_search_nbest_oracle":
-        hyp_tokens = fast_beam_search_nbest_oracle(
-            model=model,
-            decoding_graph=decoding_graph,
-            encoder_out=encoder_out,
-            encoder_out_lens=encoder_out_lens,
-            beam=params.beam,
-            max_contexts=params.max_contexts,
-            max_states=params.max_states,
-            num_paths=params.num_paths,
-            ref_texts=sp.encode(supervisions["text"]),
-            nbest_scale=params.nbest_scale,
-        )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
@@ -517,55 +415,9 @@
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
-            context_graph=context_graph,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
-        hyp_tokens = modified_beam_search_lm_shallow_fusion(
-            model=model,
-            encoder_out=encoder_out,
-            encoder_out_lens=encoder_out_lens,
-            beam=params.beam_size,
-            LM=LM,
-        )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
-    elif params.decoding_method == "modified_beam_search_LODR":
-        hyp_tokens = modified_beam_search_LODR(
-            model=model,
-            encoder_out=encoder_out,
-            encoder_out_lens=encoder_out_lens,
-            beam=params.beam_size,
-            LODR_lm=ngram_lm,
-            LODR_lm_scale=ngram_lm_scale,
-            LM=LM,
-            context_graph=context_graph,
-        )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
-    elif params.decoding_method == "modified_beam_search_lm_rescore":
-        lm_scale_list = [0.01 * i for i in range(10, 50)]
-        ans_dict = modified_beam_search_lm_rescore(
-            model=model,
-            encoder_out=encoder_out,
-            encoder_out_lens=encoder_out_lens,
-            beam=params.beam_size,
-            LM=LM,
-            lm_scale_list=lm_scale_list,
-        )
-    elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
-        lm_scale_list = [0.02 * i for i in range(2, 30)]
-        ans_dict = modified_beam_search_lm_rescore_LODR(
-            model=model,
-            encoder_out=encoder_out,
-            encoder_out_lens=encoder_out_lens,
-            beam=params.beam_size,
-            LM=LM,
-            LODR_lm=ngram_lm,
-            sp=sp,
-            lm_scale_list=lm_scale_list,
-        )
     else:
         batch_size = encoder_out.size(0)

@@ -593,6 +445,14 @@

     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
+    elif params.decoding_method == "fast_beam_search":
+        return {
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
+        }
     elif "fast_beam_search" in params.decoding_method:
         key = f"beam_{params.beam}_"
         key += f"max_contexts_{params.max_contexts}_"
@@ -604,22 +464,6 @@
             key += f"_ngram_lm_scale_{params.ngram_lm_scale}"

         return {key: hyps}
-    elif "modified_beam_search" in params.decoding_method:
-        prefix = f"beam_size_{params.beam_size}"
-        if params.decoding_method in (
-            "modified_beam_search_lm_rescore",
-            "modified_beam_search_lm_rescore_LODR",
-        ):
-            ans = dict()
-            assert ans_dict is not None
-            for key, hyps in ans_dict.items():
-                hyps = [sp.decode(hyp).split() for hyp in hyps]
-                ans[f"{prefix}_{key}"] = hyps
-            return ans
-        else:
-            if params.has_contexts:
-                prefix += f"-context-score-{params.context_score}"
-            return {prefix: hyps}
     else:
         return {f"beam_size_{params.beam_size}": hyps}

@@ -629,12 +473,8 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
-    word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
-    context_graph: Optional[ContextGraph] = None,
-    LM: Optional[LmScorer] = None,
-    ngram_lm=None,
-    ngram_lm_scale: float = 0.0,
+    word_table: Optional[k2.SymbolTable] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.

@@ -682,12 +522,8 @@
             model=model,
             sp=sp,
             decoding_graph=decoding_graph,
-            context_graph=context_graph,
             word_table=word_table,
             batch=batch,
-            LM=LM,
-            ngram_lm=ngram_lm,
-            ngram_lm_scale=ngram_lm_scale,
         )

         for name, hyps in hyps_dict.items():
@@ -755,7 +591,7 @@
 @torch.no_grad()
 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    AmiAsrDataModule.add_arguments(parser)
     LmScorer.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
@@ -767,37 +603,16 @@
         "greedy_search",
         "beam_search",
         "fast_beam_search",
-        "fast_beam_search_nbest",
         "fast_beam_search_nbest_LG",
-        "fast_beam_search_nbest_oracle",
         "modified_beam_search",
-        "modified_beam_search_LODR",
-        "modified_beam_search_lm_shallow_fusion",
-        "modified_beam_search_lm_rescore",
-        "modified_beam_search_lm_rescore_LODR",
     )
     params.res_dir = params.exp_dir / params.decoding_method

-    if os.path.exists(params.context_file):
-        params.has_contexts = True
-    else:
-        params.has_contexts = False
-
     if params.iter > 0:
         params.suffix = f"iter-{params.iter}-avg-{params.avg}"
     else:
         params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

-    if params.causal:
-        assert (
-            "," not in params.chunk_size
-        ), "chunk_size should be one value in decoding."
-        assert (
-            "," not in params.left_context_frames
-        ), "left_context_frames should be one value in decoding."
-        params.suffix += f"-chunk-{params.chunk_size}"
-        params.suffix += f"-left-context-{params.left_context_frames}"
-
     if "fast_beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam}"
         params.suffix += f"-max-contexts-{params.max_contexts}"
@@ -809,27 +624,10 @@
                 params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        if params.decoding_method in (
-            "modified_beam_search",
-            "modified_beam_search_LODR",
-        ):
-            if params.has_contexts:
-                params.suffix += f"-context-score-{params.context_score}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

-    if params.use_shallow_fusion:
-        params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
-
-    if "LODR" in params.decoding_method:
-        params.suffix += (
-            f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
-        )
-
-    if params.use_averaged_model:
-        params.suffix += "-use-averaged-model"
-
     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")
@@ -932,54 +730,6 @@
     model.to(device)
     model.eval()

-    # only load the neural network LM if required
-    if params.use_shallow_fusion or params.decoding_method in (
-        "modified_beam_search_lm_rescore",
-        "modified_beam_search_lm_rescore_LODR",
-        "modified_beam_search_lm_shallow_fusion",
-        "modified_beam_search_LODR",
-    ):
-        LM = LmScorer(
-            lm_type=params.lm_type,
-            params=params,
-            device=device,
-            lm_scale=params.lm_scale,
-        )
-        LM.to(device)
-        LM.eval()
-    else:
-        LM = None
-
-    # only load N-gram LM when needed
-    if params.decoding_method == "modified_beam_search_lm_rescore_LODR":
-        try:
-            import kenlm
-        except ImportError:
-            print("Please install kenlm first. You can use")
-            print("  pip install https://github.com/kpu/kenlm/archive/master.zip")
-            print("to install it")
-            import sys
-
-            sys.exit(-1)
-        ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
-        logging.info(f"lm filename: {ngram_file_name}")
-        ngram_lm = kenlm.Model(ngram_file_name)
-        ngram_lm_scale = None  # use a list to search
-
-    elif params.decoding_method == "modified_beam_search_LODR":
-        lm_filename = f"{params.tokens_ngram}gram.fst.txt"
-        logging.info(f"Loading token level lm: {lm_filename}")
-        ngram_lm = NgramLm(
-            str(params.lang_dir / lm_filename),
-            backoff_id=params.backoff_id,
-            is_binary=False,
-        )
-        logging.info(f"num states: {ngram_lm.lm.num_states}")
-        ngram_lm_scale = params.ngram_lm_scale
-    else:
-        ngram_lm = None
-        ngram_lm_scale = None
-
     if "fast_beam_search" in params.decoding_method:
         if params.decoding_method == "fast_beam_search_nbest_LG":
             lexicon = Lexicon(params.lang_dir)
@@ -997,46 +747,51 @@
         decoding_graph = None
         word_table = None

-    if "modified_beam_search" in params.decoding_method:
-        if os.path.exists(params.context_file):
-            contexts = []
-            for line in open(params.context_file).readlines():
-                contexts.append(line.strip())
-            context_graph = ContextGraph(params.context_score)
-            context_graph.build(sp.encode(contexts))
-        else:
-            context_graph = None
-    else:
-        context_graph = None
-
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

     # we need cut ids to display recognition results.
     args.return_cuts = True
-    librispeech = LibriSpeechAsrDataModule(args)
-    test_clean_cuts = librispeech.test_clean_cuts()
-    test_other_cuts = librispeech.test_other_cuts()
+    ami = AmiAsrDataModule(args)

-    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
-    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+    dev_ihm_cuts = ami.dev_ihm_cuts()
+    test_ihm_cuts = ami.test_ihm_cuts()
+    dev_sdm_cuts = ami.dev_sdm_cuts()
+    test_sdm_cuts = ami.test_sdm_cuts()
+    dev_gss_cuts = ami.dev_gss_cuts()
+    test_gss_cuts = ami.test_gss_cuts()

-    test_sets = ["test-clean", "test-other"]
-    test_dl = [test_clean_dl, test_other_dl]
+    dev_ihm_dl = ami.test_dataloaders(dev_ihm_cuts)
+    test_ihm_dl = ami.test_dataloaders(test_ihm_cuts)
+    dev_sdm_dl = ami.test_dataloaders(dev_sdm_cuts)
+    test_sdm_dl = ami.test_dataloaders(test_sdm_cuts)
+    if dev_gss_cuts is not None:
+        dev_gss_dl = ami.test_dataloaders(dev_gss_cuts)
+    if test_gss_cuts is not None:
+        test_gss_dl = ami.test_dataloaders(test_gss_cuts)

-    for test_set, test_dl in zip(test_sets, test_dl):
+    test_sets = {
+        "dev_ihm": (dev_ihm_dl, dev_ihm_cuts),
+        "test_ihm": (test_ihm_dl, test_ihm_cuts),
+        "dev_sdm": (dev_sdm_dl, dev_sdm_cuts),
+        "test_sdm": (test_sdm_dl, test_sdm_cuts),
+    }
+    if dev_gss_cuts is not None:
+        test_sets["dev_gss"] = (dev_gss_dl, dev_gss_cuts)
+    if test_gss_cuts is not None:
+        test_sets["test_gss"] = (test_gss_dl, test_gss_cuts)
+
+    for test_set in test_sets:
+        logging.info(f"Decoding {test_set}")
+        dl, cuts = test_sets[test_set]
         results_dict = decode_dataset(
-            dl=test_dl,
+            dl=dl,
             params=params,
             model=model,
             sp=sp,
             word_table=word_table,
             decoding_graph=decoding_graph,
-            context_graph=context_graph,
-            LM=LM,
-            ngram_lm=ngram_lm,
-            ngram_lm_scale=ngram_lm_scale,
         )

         save_results(