From b8dbad5156e41ecfc848feaf7c218dca9cee823e Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Mon, 18 Oct 2021 14:38:07 +0800
Subject: [PATCH] Update decoding script.

---
 egs/librispeech/ASR/conformer_mmi/decode.py   | 125 ++++++++++++++----
 .../ASR/conformer_mmi/train-with-attention.py |   2 +-
 2 files changed, 97 insertions(+), 30 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_mmi/decode.py b/egs/librispeech/ASR/conformer_mmi/decode.py
index dc2e449c2..79012c98f 100755
--- a/egs/librispeech/ASR/conformer_mmi/decode.py
+++ b/egs/librispeech/ASR/conformer_mmi/decode.py
@@ -23,6 +23,7 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
 import k2
+import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
@@ -77,6 +78,9 @@ def get_parser():
         default="attention-decoder",
         help="""Decoding method.
         Supported values are:
+            - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
+              model, i.e., lang_dir/bpe.model, to convert word pieces to words.
+              It needs neither a lexicon nor an n-gram LM.
             - (1) 1best. Extract the best path from the decoding lattice as the
               decoding result.
             - (2) nbest. Extract n paths from the decoding lattice; the path
@@ -106,7 +110,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--lattice-score-scale",
+        "--nbest-scale",
         type=float,
         default=0.5,
         help="""The scale to be applied to `lattice.scores`.
@@ -122,7 +126,7 @@ def get_parser():
         type=str2bool,
         default=False,
         help="""When enabled, the averaged model is saved to
-        conformer_mmi/exp/pretrained.pt. Note: only model.state_dict() is saved.
+        conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
         pretrained.pt contains a dict {"model": model.state_dict()},
         which can be loaded by `icefall.checkpoint.load_checkpoint()`.
         """,
@@ -131,17 +135,24 @@ def get_parser():
     )
 
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="conformer_mmi/exp",
+        default="conformer_mmi/exp_500",
         help="The experiment dir",
     )
 
     parser.add_argument(
         "--lang-dir",
         type=str,
-        default="data/lang_bpe",
+        default="data/lang_bpe_500",
         help="The lang dir",
     )
 
+    parser.add_argument(
+        "--num-decoder-layers",
+        type=int,
+        default=6,
+        help="Number of attention decoder layers",
+    )
+
     return parser
@@ -156,7 +167,6 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "nhead": 8,
             "attention_dim": 512,
-            "num_decoder_layers": 6,
             # parameters for decoding
             "search_beam": 20,
             "output_beam": 8,
@@ -171,13 +181,15 @@ def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
-    HLG: k2.Fsa,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
     batch: dict,
     word_table: k2.SymbolTable,
     sos_id: int,
     eos_id: int,
     G: Optional[k2.Fsa] = None,
-) -> Dict[str, List[List[int]]]:
+) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -202,7 +214,11 @@ def decode_one_batch(
       model:
         The neural model.
       HLG:
-        The decoding graph.
+        The decoding graph. Used only when params.method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.method is ctc-decoding.
       batch:
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
@@ -221,7 +237,10 @@ def decode_one_batch(
       Return the decoding result. See above description for the format of
      the returned dict.
""" - device = HLG.device + if HLG is not None: + device = HLG.device + else: + device = H.device feature = batch["inputs"] assert feature.ndim == 3 feature = feature.to(device) @@ -241,9 +260,17 @@ def decode_one_batch( 1, ).to(torch.int32) + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + lattice = get_lattice( nnet_output=nnet_output, - HLG=HLG, + decoding_graph=decoding_graph, supervision_segments=supervision_segments, search_beam=params.search_beam, output_beam=params.output_beam, @@ -252,6 +279,24 @@ def decode_one_batch( subsampling_factor=params.subsampling_factor, ) + if params.method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + if params.method == "nbest-oracle": # Note: You can also pass rescored lattices to it. # We choose the HLG decoded lattice for speed reasons @@ -262,12 +307,12 @@ def decode_one_batch( num_paths=params.num_paths, ref_texts=supervisions["text"], word_table=word_table, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, oov="", ) hyps = get_texts(best_path) hyps = [[word_table[i] for i in ids] for ids in hyps] - key = f"oracle_{params.num_paths}_lattice_score_scale_{params.lattice_score_scale}" # noqa + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa return {key: hyps} if params.method in ["1best", "nbest"]: @@ -281,9 +326,9 @@ def decode_one_batch( lattice=lattice, num_paths=params.num_paths, use_double_scores=params.use_double_scores, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, ) - key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa hyps = get_texts(best_path) hyps = [[word_table[i] for i in ids] for ids in hyps] @@ -305,7 +350,7 @@ def decode_one_batch( G=G, num_paths=params.num_paths, lm_scale_list=lm_scale_list, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, ) elif params.method == "whole-lattice-rescoring": best_path_dict = rescore_with_whole_lattice( @@ -331,7 +376,7 @@ def decode_one_batch( memory_key_padding_mask=memory_key_padding_mask, sos_id=sos_id, eos_id=eos_id, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, ) else: assert False, f"Unsupported decoding method: {params.method}" @@ -344,7 +389,7 @@ def decode_one_batch( ans[lm_scale_str] = hyps else: for lm_scale in lm_scale_list: - ans[lm_scale_str] = [[] * lattice.shape[0]] + ans["empty"] = [[] * lattice.shape[0]] return ans @@ -352,12 +397,14 @@ def decode_dataset( dl: torch.utils.data.DataLoader, params: AttributeDict, model: nn.Module, - HLG: k2.Fsa, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], word_table: k2.SymbolTable, sos_id: int, eos_id: int, G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[int], List[int]]]]: +) -> Dict[str, List[Tuple[List[str], 
    """Decode dataset.
 
     Args:
@@ -368,7 +415,11 @@ def decode_dataset(
       model:
         The neural model.
       HLG:
-        The decoding graph.
+        The decoding graph. Used only when params.method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.method is ctc-decoding.
       word_table:
         It is the word symbol table.
       sos_id:
@@ -403,6 +454,8 @@ def decode_dataset(
             params=params,
             model=model,
             HLG=HLG,
+            H=H,
+            bpe_model=bpe_model,
             batch=batch,
             word_table=word_table,
             G=G,
@@ -481,11 +534,11 @@ def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
 
     params = get_params()
     params.update(vars(args))
-    params.exp_dir = Path(params.exp_dir)
-    params.lang_dir = Path(params.lang_dir)
 
     setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
     logging.info("Decoding started")
@@ -510,14 +563,26 @@ def main():
     sos_id = graph_compiler.sos_id
     eos_id = graph_compiler.eos_id
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
-    HLG = HLG.to(device)
-    assert HLG.requires_grad is False
+    if params.method == "ctc-decoding":
+        HLG = None
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(str(params.lang_dir / "bpe.model"))
+    else:
+        H = None
+        bpe_model = None
+        HLG = k2.Fsa.from_dict(
+            torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+        )
+        HLG = HLG.to(device)
+        assert HLG.requires_grad is False
 
-    if not hasattr(HLG, "lm_scores"):
-        HLG.lm_scores = HLG.scores.clone()
+        if not hasattr(HLG, "lm_scores"):
+            HLG.lm_scores = HLG.scores.clone()
 
     if params.method in (
         "nbest-rescoring",
@@ -607,6 +672,8 @@ def main():
             params=params,
             model=model,
             HLG=HLG,
+            H=H,
+            bpe_model=bpe_model,
             word_table=lexicon.word_table,
             G=G,
             sos_id=sos_id,
diff --git a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py
index a66f77676..a3e266814 100755
--- a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py
+++ b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py
@@ -373,7 +373,7 @@ def compute_loss(
             params.batch_idx_train > params.use_ali_until
             and params.beam_size < 8
         ):
-            logging.info("Change beam size to 8")
+            # logging.info("Change beam size to 8")
             params.beam_size = 8
         else:
             params.beam_size = 6
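
Editor's note: the sketch below is illustrative and is not part of the patch. It walks the new `ctc-decoding` path added above in isolation: build a CTC topology H with k2, intersect it with the network output to obtain a lattice, take the one-best path, and map the resulting word-piece IDs back to words with SentencePiece. The helper names (`get_lattice`, `one_best_decoding`, `get_texts`) are the icefall utilities already used by decode.py; the BPE-model path and the randomly generated `nnet_output` are assumptions made only for the example.

    import k2
    import sentencepiece as spm
    import torch

    from icefall.decode import get_lattice, one_best_decoding
    from icefall.utils import get_texts

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Assumed location of the BPE model; matches the default --lang-dir above.
    bpe_model = spm.SentencePieceProcessor()
    bpe_model.load("data/lang_bpe_500/bpe.model")

    # CTC topology over all word pieces (token 0 is the blank).
    max_token_id = bpe_model.get_piece_size() - 1
    H = k2.ctc_topo(max_token=max_token_id, modified=False, device=device)

    # Fake network output for illustration: 1 utterance, 20 frames of
    # log-probabilities over the word-piece vocabulary.
    N, T, C = 1, 20, max_token_id + 1
    nnet_output = torch.randn(N, T, C, device=device).log_softmax(dim=-1)

    # One row per utterance: (sequence_idx, start_frame, num_frames).
    supervision_segments = torch.tensor([[0, 0, T]], dtype=torch.int32)

    lattice = get_lattice(
        nnet_output=nnet_output,
        decoding_graph=H,
        supervision_segments=supervision_segments,
        search_beam=20,
        output_beam=8,
        min_active_states=30,
        max_active_states=10000,
        subsampling_factor=1,
    )

    best_path = one_best_decoding(lattice=lattice, use_double_scores=True)

    # aux_labels hold token IDs because we decoded with H, not HLG.
    token_ids = get_texts(best_path)
    hyps = [s.split() for s in bpe_model.decode(token_ids)]
    print(hyps)

With a real model, `nnet_output` would come from `model(feature, supervisions)` as in decode.py, and the lattice would instead be built from HLG for the lexicon-based methods.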