diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index bcd363df3..177e33a6e 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -428,8 +428,6 @@ def decode_dataset( The first is the reference transcript, and the second is the predicted result. """ - results = [] - num_cuts = 0 try: diff --git a/egs/librispeech/ASR/transducer/decode.py b/egs/librispeech/ASR/transducer/decode.py index bcd363df3..c2567ed61 100755 --- a/egs/librispeech/ASR/transducer/decode.py +++ b/egs/librispeech/ASR/transducer/decode.py @@ -20,31 +20,22 @@ import argparse import logging from collections import defaultdict from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Tuple -import k2 import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from conformer import Conformer +from transducer.beam_search import greedy_search +from transducer.conformer import Conformer +from transducer.decoder import Decoder +from transducer.joiner import Joiner +from transducer.model import Transducer -from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.decode import ( - get_lattice, - nbest_decoding, - nbest_oracle, - one_best_decoding, - rescore_with_attention_decoder, - rescore_with_n_best_list, - rescore_with_whole_lattice, -) from icefall.env import get_env_info -from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, - get_texts, setup_logger, store_transcripts, write_error_stats, @@ -72,76 +63,18 @@ def get_parser(): "'--epoch'. ", ) - parser.add_argument( - "--method", - type=str, - default="attention-decoder", - help="""Decoding method. - Supported values are: - - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece - model, i.e., lang_dir/bpe.model, to convert word pieces to words. - It needs neither a lexicon nor an n-gram LM. - - (1) 1best. Extract the best path from the decoding lattice as the - decoding result. - - (2) nbest. Extract n paths from the decoding lattice; the path - with the highest score is the decoding result. - - (3) nbest-rescoring. Extract n paths from the decoding lattice, - rescore them with an n-gram LM (e.g., a 4-gram LM), the path with - the highest score is the decoding result. - - (4) whole-lattice-rescoring. Rescore the decoding lattice with an - n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice - is the decoding result. - - (5) attention-decoder. Extract n paths from the LM rescored - lattice, the path with the highest score is the decoding result. - - (6) nbest-oracle. Its WER is the lower bound of any n-best - rescoring method can achieve. Useful for debugging n-best - rescoring method. - """, - ) - - parser.add_argument( - "--num-paths", - type=int, - default=100, - help="""Number of paths for n-best based decoding method. - Used only when "method" is one of the following values: - nbest, nbest-rescoring, attention-decoder, and nbest-oracle - """, - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""The scale to be applied to `lattice.scores`. - It's needed if you use any kinds of n-best based rescoring. - Used only when "method" is one of the following values: - nbest, nbest-rescoring, attention-decoder, and nbest-oracle - A smaller value results in more unique paths. - """, - ) - parser.add_argument( "--exp-dir", type=str, - default="conformer_ctc/exp", + default="transducer/exp", help="The experiment dir", ) parser.add_argument( - "--lang-dir", + "--bpe-model", type=str, - default="data/lang_bpe_500", - help="The lang dir", - ) - - parser.add_argument( - "--lm-dir", - type=str, - default="data/lm", - help="""The LM dir. - It should contain either G_4_gram.pt or G_4_gram.fst.txt - """, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", ) return parser @@ -151,250 +84,138 @@ def get_params() -> AttributeDict: params = AttributeDict( { # parameters for conformer + "feature_dim": 80, + "encoder_out_dim": 512, "subsampling_factor": 4, + "attention_dim": 512, + "nhead": 8, + "dim_feedforward": 2048, + "num_encoder_layers": 12, "vgg_frontend": False, "use_feat_batchnorm": True, - "feature_dim": 80, - "nhead": 8, - "attention_dim": 512, - "num_decoder_layers": 6, - # parameters for decoding - "search_beam": 20, - "output_beam": 8, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, + # decoder params + "decoder_embedding_dim": 1024, + "num_decoder_layers": 4, + "decoder_hidden_dim": 512, "env_info": get_env_info(), } ) return params +def get_encoder_model(params: AttributeDict): + # TODO: We can add an option to switch between Conformer and Transformer + encoder = Conformer( + num_features=params.feature_dim, + output_dim=params.encoder_out_dim, + subsampling_factor=params.subsampling_factor, + d_model=params.attention_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + return encoder + + +def get_decoder_model(params: AttributeDict): + decoder = Decoder( + vocab_size=params.vocab_size, + embedding_dim=params.decoder_embedding_dim, + blank_id=params.blank_id, + sos_id=params.sos_id, + num_layers=params.num_decoder_layers, + hidden_dim=params.decoder_hidden_dim, + output_dim=params.encoder_out_dim, + ) + return decoder + + +def get_joiner_model(params: AttributeDict): + joiner = Joiner( + input_dim=params.encoder_out_dim, + output_dim=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict): + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + ) + return model + + def decode_one_batch( params: AttributeDict, model: nn.Module, - HLG: Optional[k2.Fsa], - H: Optional[k2.Fsa], - bpe_model: Optional[spm.SentencePieceProcessor], + sp: spm.SentencePieceProcessor, batch: dict, - word_table: k2.SymbolTable, - sos_id: int, - eos_id: int, - G: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: - key: It indicates the setting used for decoding. For example, - if no rescoring is used, the key is the string `no_rescore`. - If LM rescoring is used, the key is the string `lm_scale_xxx`, - where `xxx` is the value of `lm_scale`. An example key is - `lm_scale_0.7` + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" - value: It contains the decoding result. `len(value)` equals to batch size. `value[i]` is the decoding result for the i-th utterance in the given batch. Args: params: It's the return value of :func:`get_params`. - - - params.method is "1best", it uses 1best decoding without LM rescoring. - - params.method is "nbest", it uses nbest decoding without LM rescoring. - - params.method is "nbest-rescoring", it uses nbest LM rescoring. - - params.method is "whole-lattice-rescoring", it uses whole lattice LM - rescoring. - model: The neural model. - HLG: - The decoding graph. Used only when params.method is NOT ctc-decoding. - H: - The ctc topo. Used only when params.method is ctc-decoding. - bpe_model: - The BPE model. Used only when params.method is ctc-decoding. + sp: + The BPE model. batch: It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. - word_table: - The word symbol table. - sos_id: - The token ID of the SOS. - eos_id: - The token ID of the EOS. - G: - An LM. It is not None when params.method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. Returns: Return the decoding result. See above description for the format of - the returned dict. Note: If it decodes to nothing, then return None. + the returned dict. """ - if HLG is not None: - device = HLG.device - else: - device = H.device + device = model.device feature = batch["inputs"] assert feature.ndim == 3 + feature = feature.to(device) # at entry, feature is (N, T, C) supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) - nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) - # nnet_output is (N, T, C) - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - supervisions["start_frame"] // params.subsampling_factor, - supervisions["num_frames"] // params.subsampling_factor, - ), - 1, - ).to(torch.int32) - - if H is None: - assert HLG is not None - decoding_graph = HLG - else: - assert HLG is None - assert bpe_model is not None - decoding_graph = H - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=decoding_graph, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens ) + hyps = [] + batch_size = encoder_out.size(0) - if params.method == "ctc-decoding": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - # Note: `best_path.aux_labels` contains token IDs, not word IDs - # since we are using H, not HLG here. - # - # token_ids is a lit-of-list of IDs - token_ids = get_texts(best_path) + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + hyp = greedy_search(model=model, encoder_out=encoder_out_i) + hyps.append(sp.decode(hyp).split()) - # hyps is a list of str, e.g., ['xxx yyy zzz', ...] - hyps = bpe_model.decode(token_ids) - - # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] - hyps = [s.split() for s in hyps] - key = "ctc-decoding" - return {key: hyps} - - if params.method == "nbest-oracle": - # Note: You can also pass rescored lattices to it. - # We choose the HLG decoded lattice for speed reasons - # as HLG decoding is faster and the oracle WER - # is only slightly worse than that of rescored lattices. - best_path = nbest_oracle( - lattice=lattice, - num_paths=params.num_paths, - ref_texts=supervisions["text"], - word_table=word_table, - nbest_scale=params.nbest_scale, - oov="", - ) - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa - return {key: hyps} - - if params.method in ["1best", "nbest"]: - if params.method == "1best": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - key = "no_rescore" - else: - best_path = nbest_decoding( - lattice=lattice, - num_paths=params.num_paths, - use_double_scores=params.use_double_scores, - nbest_scale=params.nbest_scale, - ) - key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa - - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - return {key: hyps} - - assert params.method in [ - "nbest-rescoring", - "whole-lattice-rescoring", - "attention-decoder", - ] - - lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] - lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] - lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] - - if params.method == "nbest-rescoring": - best_path_dict = rescore_with_n_best_list( - lattice=lattice, - G=G, - num_paths=params.num_paths, - lm_scale_list=lm_scale_list, - nbest_scale=params.nbest_scale, - ) - elif params.method == "whole-lattice-rescoring": - best_path_dict = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=lm_scale_list, - ) - elif params.method == "attention-decoder": - # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. - rescored_lattice = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=None, - ) - # TODO: pass `lattice` instead of `rescored_lattice` to - # `rescore_with_attention_decoder` - - best_path_dict = rescore_with_attention_decoder( - lattice=rescored_lattice, - num_paths=params.num_paths, - model=model, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - sos_id=sos_id, - eos_id=eos_id, - nbest_scale=params.nbest_scale, - ) - else: - assert False, f"Unsupported decoding method: {params.method}" - - ans = dict() - if best_path_dict is not None: - for lm_scale_str, best_path in best_path_dict.items(): - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - ans[lm_scale_str] = hyps - else: - ans = None - return ans + return {"greedy_search": hyps} + # TODO: Implement beam search def decode_dataset( dl: torch.utils.data.DataLoader, params: AttributeDict, model: nn.Module, - HLG: Optional[k2.Fsa], - H: Optional[k2.Fsa], - bpe_model: Optional[spm.SentencePieceProcessor], - word_table: k2.SymbolTable, - sos_id: int, - eos_id: int, - G: Optional[k2.Fsa] = None, + sp: spm.SentencePieceProcessor, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -405,31 +226,15 @@ def decode_dataset( It is returned by :func:`get_params`. model: The neural model. - HLG: - The decoding graph. Used only when params.method is NOT ctc-decoding. - H: - The ctc topo. Used only when params.method is ctc-decoding. - bpe_model: - The BPE model. Used only when params.method is ctc-decoding. - word_table: - It is the word symbol table. - sos_id: - The token ID for SOS. - eos_id: - The token ID for EOS. - G: - An LM. It is not None when params.method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. + sp: + The BPE model. Returns: - Return a dict, whose key may be "no-rescore" if no LM rescoring - is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. Its value is a list of tuples. Each tuple contains two elements: The first is the reference transcript, and the second is the predicted result. """ - results = [] - num_cuts = 0 try: @@ -444,37 +249,18 @@ def decode_dataset( hyps_dict = decode_one_batch( params=params, model=model, - HLG=HLG, - H=H, - bpe_model=bpe_model, + sp=sp, batch=batch, - word_table=word_table, - G=G, - sos_id=sos_id, - eos_id=eos_id, ) - if hyps_dict is not None: - for lm_scale, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): - ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) - - results[lm_scale].extend(this_batch) - else: - assert ( - len(results) > 0 - ), "It should not decode to empty in the first batch!" + for name, hyps in hyps_dict.items(): this_batch = [] - hyp_words = [] - for ref_text in texts: + assert len(hyps) == len(texts) + for hyp_words, ref_text in zip(hyps, texts): ref_words = ref_text.split() this_batch.append((ref_words, hyp_words)) - for lm_scale in results.keys(): - results[lm_scale].extend(this_batch) + results[name].extend(this_batch) num_cuts += len(texts) @@ -492,31 +278,22 @@ def save_results( test_set_name: str, results_dict: Dict[str, List[Tuple[List[int], List[int]]]], ): - if params.method == "attention-decoder": - # Set it to False since there are too many logs. - enable_log = False - else: - enable_log = True test_set_wers = dict() for key, results in results_dict.items(): recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" store_transcripts(filename=recog_path, texts=results) - if enable_log: - logging.info(f"The transcripts are stored in {recog_path}") + logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=enable_log + f, f"{test_set_name}-{key}", results, enable_log=True ) test_set_wers[key] = wer - if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -539,113 +316,31 @@ def main(): LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - args.lm_dir = Path(args.lm_dir) params = get_params() params.update(vars(args)) - setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") + setup_logger(f"{params.exp_dir}/log-decode") logging.info("Decoding started") - logging.info(params) - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", 0) - logging.info(f"device: {device}") + logging.info(f"Device: {device}") - graph_compiler = BpeCtcTrainingGraphCompiler( - params.lang_dir, - device=device, - sos_token="", - eos_token="", - ) - sos_id = graph_compiler.sos_id - eos_id = graph_compiler.eos_id + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) - if params.method == "ctc-decoding": - HLG = None - H = k2.ctc_topo( - max_token=max_token_id, - modified=False, - device=device, - ) - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(str(params.lang_dir / "bpe.model")) - else: - H = None - bpe_model = None - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) - ) - assert HLG.requires_grad is False + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.sos_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() - if not hasattr(HLG, "lm_scores"): - HLG.lm_scores = HLG.scores.clone() + logging.info(params) - if params.method in ( - "nbest-rescoring", - "whole-lattice-rescoring", - "attention-decoder", - ): - if not (params.lm_dir / "G_4_gram.pt").is_file(): - logging.info("Loading G_4_gram.fst.txt") - logging.warning("It may take 8 minutes.") - with open(params.lm_dir / "G_4_gram.fst.txt") as f: - first_word_disambig_id = lexicon.word_table["#0"] - - G = k2.Fsa.from_openfst(f.read(), acceptor=False) - # G.aux_labels is not needed in later computations, so - # remove it here. - del G.aux_labels - # CAUTION: The following line is crucial. - # Arcs entering the back-off state have label equal to #0. - # We have to change it to 0 here. - G.labels[G.labels >= first_word_disambig_id] = 0 - # See https://github.com/k2-fsa/k2/issues/874 - # for why we need to set G.properties to None - G.__dict__["_properties"] = None - G = k2.Fsa.from_fsas([G]).to(device) - G = k2.arc_sort(G) - # Save a dummy value so that it can be loaded in C++. - # See https://github.com/pytorch/pytorch/issues/67902 - # for why we need to do this. - G.dummy = 1 - - torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") - else: - logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) - G = k2.Fsa.from_dict(d) - - if params.method in ["whole-lattice-rescoring", "attention-decoder"]: - # Add epsilon self-loops to G as we will compose - # it with the whole lattice later - G = k2.add_epsilon_self_loops(G) - G = k2.arc_sort(G) - G = G.to(device) - - # G.lm_scores is used to replace HLG.lm_scores during - # LM rescoring. - G.lm_scores = G.scores.clone() - else: - G = None - - model = Conformer( - num_features=params.feature_dim, - nhead=params.nhead, - d_model=params.attention_dim, - num_classes=num_classes, - subsampling_factor=params.subsampling_factor, - num_decoder_layers=params.num_decoder_layers, - vgg_frontend=params.vgg_frontend, - use_feat_batchnorm=params.use_feat_batchnorm, - ) + logging.info("About to create model") + model = get_transducer_model(params) if params.avg == 1: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) @@ -661,6 +356,8 @@ def main(): model.to(device) model.eval() + model.device = device + num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -680,17 +377,13 @@ def main(): dl=test_dl, params=params, model=model, - HLG=HLG, - H=H, - bpe_model=bpe_model, - word_table=lexicon.word_table, - G=G, - sos_id=sos_id, - eos_id=eos_id, + sp=sp, ) save_results( - params=params, test_set_name=test_set, results_dict=results_dict + params=params, + test_set_name=test_set, + results_dict=results_dict, ) logging.info("Done!")