From 9cbe54732a0a9c7eab126902f95f93927e3baa4e Mon Sep 17 00:00:00 2001 From: yfy62 Date: Wed, 26 Apr 2023 17:35:19 +0800 Subject: [PATCH 1/6] Add phone based train and decode for gigaspeech --- .../asr_datamodule.py | 6 +- .../pruned_transducer_stateless2/decode.py | 508 ++++++++++++++---- .../ASR/pruned_transducer_stateless2/train.py | 129 ++++- 3 files changed, 522 insertions(+), 121 deletions(-) diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index 4d5d2b8f9..1b1e1e5bb 100644 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -387,7 +387,7 @@ class GigaSpeechAsrDataModule: @lru_cache() def train_cuts(self) -> CutSet: logging.info(f"About to get train_{self.args.subset} cuts") - path = self.args.manifest_dir / f"gigaspeech_cuts_{self.args.subset}.jsonl.gz" + path = self.args.manifest_dir / f"cuts_{self.args.subset}.jsonl.gz" cuts_train = CutSet.from_jsonl_lazy(path) return cuts_train @@ -395,7 +395,7 @@ class GigaSpeechAsrDataModule: def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") cuts_valid = load_manifest_lazy( - self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz" + self.args.manifest_dir / "cuts_DEV.jsonl.gz" ) if self.args.small_dev: return cuts_valid.subset(first=1000) @@ -406,5 +406,5 @@ class GigaSpeechAsrDataModule: def test_cuts(self) -> CutSet: logging.info("About to get test cuts") return load_manifest_lazy( - self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz" + self.args.manifest_dir / "cuts_TEST.jsonl.gz" ) diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py index 72f74c968..1c1bb1693 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py @@ -24,8 +24,7 @@ Usage: --exp-dir ./pruned_transducer_stateless2/exp \ --max-duration 600 \ --decoding-method greedy_search - -(2) beam search +(2) beam search (not recommended) ./pruned_transducer_stateless2/decode.py \ --epoch 28 \ --avg 15 \ @@ -33,7 +32,6 @@ Usage: --max-duration 600 \ --decoding-method beam_search \ --beam-size 4 - (3) modified beam search ./pruned_transducer_stateless2/decode.py \ --epoch 28 \ @@ -42,17 +40,60 @@ Usage: --max-duration 600 \ --decoding-method modified_beam_search \ --beam-size 4 - -(4) fast beam search +(4) fast beam search (one best) ./pruned_transducer_stateless2/decode.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless2/exp \ --max-duration 600 \ --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +(5) fast beam search (nbest) +./pruned_transducer_stateless2/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless2/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 +(7) fast beam search (with LG) +./pruned_transducer_stateless2/decode.py \ + 
--epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +(8) fast beam search (nbest with LG) +./pruned_transducer_stateless2/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 """ @@ -69,6 +110,9 @@ import torch.nn as nn from asr_datamodule import GigaSpeechAsrDataModule from beam_search import ( beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, fast_beam_search_one_best, greedy_search, greedy_search_batch, @@ -76,12 +120,14 @@ from beam_search import ( ) from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model + from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, find_checkpoints, load_checkpoint, ) +from icefall.lexicon import UniqLexicon from icefall.utils import ( AttributeDict, setup_logger, @@ -127,7 +173,7 @@ def get_parser(): parser.add_argument( "--use-averaged-model", type=str2bool, - default=False, + default=True, help="Whether to load averaged model. Currently it only supports " "using --epoch. If True, it would decode with the averaged model " "over the epoch range from `epoch-avg` (excluded) to `epoch`." @@ -145,10 +191,17 @@ def get_parser(): parser.add_argument( "--bpe-model", type=str, - default="data/lang_bpe_500/bpe.model", + default=None, help="Path to the BPE model", ) + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_phone", + help="The lang dir contains word table and LG graph", + ) + parser.add_argument( "--decoding-method", type=str, @@ -158,6 +211,20 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_LG + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + """, + ) + + parser.add_argument( + "--metrics", + type=str, + default="WER", + help="""Possible values are: + - WER + - PER """, ) @@ -173,27 +240,45 @@ def get_parser(): parser.add_argument( "--beam", type=float, - default=4, + default=20.0, help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", + Used only when --decoding-method is fast_beam_search, fast_beam_search_LG, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_LG or + fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. 
+ """, ) parser.add_argument( "--max-contexts", type=int, - default=4, + default=8, help="""Used only when --decoding-method is - fast_beam_search""", + fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, + fast_beam_search_nbest_LG and fast_beam_search_nbest_oracle + """, ) parser.add_argument( "--max-states", type=int, - default=8, + default=64, help="""Used only when --decoding-method is - fast_beam_search""", + fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, + fast_beam_search_nbest_LG and fast_beam_search_nbest_oracle + """, ) parser.add_argument( @@ -202,6 +287,7 @@ def get_parser(): default=2, help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) + parser.add_argument( "--max-sym-per-frame", type=int, @@ -210,6 +296,24 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + return parser @@ -228,7 +332,9 @@ def decode_one_batch( params: AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, + pl: UniqLexicon, batch: dict, + word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the @@ -248,18 +354,24 @@ def decode_one_batch( The neural model. sp: The BPE model. + pl: + The phone lexicon. batch: It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. + word_table: + The word symbol table. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. + only when --decoding_method is fast_beam_search, fast_beam_search_LG, + fast_beam_search_nbest, fast_beam_search_nbest_oracle + and fast_beam_search_nbest_LG. Returns: Return the decoding result. See above description for the format of the returned dict. 
""" - device = model.device + device = next(model.parameters()).device feature = batch["inputs"] assert feature.ndim == 3 @@ -272,7 +384,10 @@ def decode_one_batch( encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] - if params.decoding_method == "fast_beam_search": + if ( + params.decoding_method == "fast_beam_search" + or params.decoding_method == "fast_beam_search_LG" + ): hyp_tokens = fast_beam_search_one_best( model=model, decoding_graph=decoding_graph, @@ -282,6 +397,58 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) + + if params.decoding_method == "fast_beam_search": + if sp is not None: + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for hyp in hyp_tokens: + hyps.append([str(i) for i in hyp]) + else: + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: @@ -290,8 +457,12 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + if sp is not None: + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for hyp in hyp_tokens: + hyps.append([str(i) for i in hyp]) elif params.decoding_method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -299,8 +470,12 @@ def decode_one_batch( encoder_out_lens=encoder_out_lens, beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + if sp is not None: + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for hyp in hyp_tokens: + hyps.append([str(i) for i in hyp]) else: batch_size = encoder_out.size(0) @@ -324,18 +499,24 @@ def decode_one_batch( raise ValueError( f"Unsupported decoding method: {params.decoding_method}" ) - hyps.append(sp.decode(hyp).split()) + if sp is not None: + hyps.append(sp.decode(hyp).split()) + else: + hyps.append([str(i) for i in hyp]) if params.decoding_method == "greedy_search": return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - 
f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} else: return {f"beam_size_{params.beam_size}": hyps} @@ -345,6 +526,8 @@ def decode_dataset( params: AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, + pl: UniqLexicon, + word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -358,9 +541,15 @@ def decode_dataset( The neural model. sp: The BPE model. + pl: + The phone lexicon. + word_table: + The word symbol table. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. + only when --decoding_method is fast_beam_search, fast_beam_search_LG, + fast_beam_search_nbest, fast_beam_search_nbest_oracle + and fast_beam_search_nbest_LG. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -375,29 +564,82 @@ def decode_dataset( except TypeError: num_batches = "?" - log_interval = 20 + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 results = defaultdict(list) for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + if sp is not None: + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - batch=batch, - ) + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + pl=pl, + word_table=word_table, + decoding_graph=decoding_graph, + batch=batch, + ) - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) - results[name].extend(this_batch) + results[name].extend(this_batch) + else: + if params.metrics == "WER": + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + pl=pl, + word_table=word_table, + decoding_graph=decoding_graph, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + elif params.metrics == "PER": + texts = batch["supervisions"]["text"] + token_ids = pl.texts_to_token_ids(texts).tolist() + + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + 
model=model, + sp=sp, + pl=pl, + word_table=word_table, + decoding_graph=decoding_graph, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(token_ids) + for cut_id, hyp_id, ref_token_id in zip(cut_ids, hyps, token_ids): + ref_token_id = [str(i) for i in ref_token_id] + this_batch.append((cut_id, ref_token_id, hyp_id)) + + results[name].extend(this_batch) num_cuts += len(texts) @@ -413,38 +655,73 @@ def save_results( test_set_name: str, results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = post_processing(results) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") + if params.metrics == "WER": + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = post_processing(results) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) + elif params.metrics == "PER": + test_set_pers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = post_processing(results) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out PERs, per-phone error statistics and aligned + # ref/hyp pairs. 
+            errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
+            with open(errs_filename, "w") as f:
+                per = write_error_stats(
+                    f, f"{test_set_name}-{key}", results, enable_log=True
+                )
+                test_set_pers[key] = per
+
+            logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+        test_set_pers = sorted(test_set_pers.items(), key=lambda x: x[1])
+        errs_info = params.res_dir / f"per-summary-{test_set_name}-{params.suffix}.txt"
+        with open(errs_info, "w") as f:
+            print("settings\tPER", file=f)
+            for key, val in test_set_pers:
+                print("{}\t{}".format(key, val), file=f)
+
+        s = "\nFor {}, PER of different settings are:\n".format(test_set_name)
+        note = "\tbest for {}".format(test_set_name)
+        for key, val in test_set_pers:
+            s += "{}\t{}{}\n".format(key, val, note)
+            note = ""
+        logging.info(s)
 
 
 @torch.no_grad()
@@ -457,12 +734,31 @@ def main():
     params = get_params()
     params.update(vars(args))
 
-    assert params.decoding_method in (
-        "greedy_search",
-        "beam_search",
-        "fast_beam_search",
-        "modified_beam_search",
-    )
+    if params.bpe_model is not None:
+        assert params.decoding_method in (
+            "greedy_search",
+            "beam_search",
+            "fast_beam_search",
+            "fast_beam_search_LG",
+            "fast_beam_search_nbest",
+            "fast_beam_search_nbest_LG",
+            "fast_beam_search_nbest_oracle",
+            "modified_beam_search",
+        )
+    else:
+        if params.metrics == "PER":
+            assert params.decoding_method in (
+                "greedy_search",
+                "beam_search",
+                "modified_beam_search",
+                "fast_beam_search",
+            ), "Decoding methods without an L or LG graph must be scored with PER"
+        elif params.metrics == "WER":
+            assert params.decoding_method in (
+                "fast_beam_search_LG",
+                "fast_beam_search_nbest_LG",
+            ), "Decoding methods with an L or LG graph must be scored with WER"
 
     params.res_dir = params.exp_dir / params.decoding_method
 
     if params.iter > 0:
@@ -474,8 +770,13 @@ def main():
         params.suffix += f"-beam-{params.beam}"
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
+        if "nbest" in params.decoding_method:
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
+        if "LG" in params.decoding_method:
+            params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-beam-{params.beam_size}"
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -492,13 +793,20 @@ def main():
 
     logging.info(f"Device: {device}")
 
-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    if params.bpe_model is not None:
+        sp = spm.SentencePieceProcessor()
+        sp.load(params.bpe_model)
 
-    # <blk> and <unk> are defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.unk_id = sp.piece_to_id("<unk>")
-    params.vocab_size = sp.get_piece_size()
+        # <blk> and <unk> are defined in local/train_bpe_model.py
+        params.blank_id = sp.piece_to_id("<blk>")
+        params.unk_id = sp.piece_to_id("<unk>")
+        params.vocab_size = sp.get_piece_size()
+        pl = None  # not used on the BPE path; keeps decode_dataset(..., pl=pl) below from raising NameError
+    else:
+        pl = UniqLexicon(params.lang_dir)
+        params.blank_id = 0
+        params.vocab_size = max(pl.tokens) + 1
+        sp = None
 
     logging.info(params)
 
@@ -585,10 +893,24 @@ def main():
     model.to(device)
     model.eval()
 
-    if params.decoding_method == "fast_beam_search":
-        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    if "fast_beam_search" in params.decoding_method:
+        if (
+            params.decoding_method == "fast_beam_search_LG"
+            or params.decoding_method == "fast_beam_search_nbest_LG"
+        ):
+            word_table = pl.word_table
+            lg_filename = params.lang_dir / "LG.pt"
+            logging.info(f"Loading {lg_filename}")
+            decoding_graph = k2.Fsa.from_dict(
+                torch.load(lg_filename, map_location=device)
+            )
+            decoding_graph.scores *= params.ngram_lm_scale
+        else:
+            word_table = None
+            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
     else:
         decoding_graph = None
+        word_table = None
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -612,6 +934,8 @@ def main():
             params=params,
             model=model,
             sp=sp,
+            pl=pl,
+            word_table=word_table,
             decoding_graph=decoding_graph,
         )
 
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py
index 578bd9218..2ce785ef8 100755
--- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py
@@ -1,7 +1,8 @@
 #!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
-#                                        Wei Kang
-#                                        Mingshuang Luo)
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+#                                            Wei Kang,
+#                                            Mingshuang Luo,
+#                                            Yifan Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -19,13 +20,15 @@
 """
 Usage:
 
+(1) bpe
 export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
 
 ./pruned_transducer_stateless2/train.py \
   --world-size 8 \
   --num-epochs 30 \
-  --start-epoch 0 \
+  --start-epoch 1 \
   --exp-dir pruned_transducer_stateless2/exp \
+  --subset XL \
   --max-duration 120
 
 # For mixed precision training:
 
 ./pruned_transducer_stateless2/train.py \
   --world-size 8 \
   --num-epochs 30 \
-  --start-epoch 0 \
-  --use_fp16 1 \
+  --start-epoch 1 \
+  --use-fp16 1 \
   --exp-dir pruned_transducer_stateless2/exp \
+  --subset XL \
   --max-duration 200
 
+(2) phone
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+./pruned_transducer_stateless2/train.py \
+  --world-size 8 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir pruned_transducer_stateless2/exp \
+  --subset XL \
+  --lang-type phone \
+  --context-size 4 \
+  --max-duration 300
+
+# For mixed precision training:
+
+./pruned_transducer_stateless2/train.py \
+  --world-size 8 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir pruned_transducer_stateless2/exp \
+  --subset XL \
+  --lang-type phone \
+  --context-size 4 \
+  --max-duration 750
 """
 
@@ -77,6 +106,7 @@
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.lexicon import UniqLexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -119,8 +149,8 @@ def get_parser():
         "--start-epoch",
         type=int,
         default=1,
-    help="""Resume training from this epoch.
-    If larger than 1, it will load checkpoint from
+    help="""Resume training from this epoch.
+    If it is larger than 1, it will load checkpoint from
     exp-dir/epoch-{start_epoch-1}.pt
     """,
     )
 
@@ -144,6 +174,13 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--lang-type",
+        type=str,
+        default="bpe",
+        help="Either bpe or phone",
+    )
+
     parser.add_argument(
         "--bpe-model",
         type=str,
@@ -151,6 +188,13 @@ def get_parser():
         help="Path to the BPE model",
     )
 
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_phone",
+        help="The lang dir that contains the phone lexicon",
+    )
+
     parser.add_argument(
         "--initial-lr",
         type=float,
@@ -231,7 +275,7 @@ def get_parser():
     parser.add_argument(
         "--save-every-n",
         type=int,
-        default=8000,
+        default=20000,
         help="""Save checkpoint after processing this number of batches"
         periodically. We save checkpoint to exp-dir/ whenever
         params.batch_idx_train % save_every_n == 0. The checkpoint filename
@@ -244,7 +288,7 @@ def get_parser():
     parser.add_argument(
         "--keep-last-k",
         type=int,
-        default=30,
+        default=20,
         help="""Only keep this number of checkpoints on disk.
         For instance, if it is 3, there are only 3 checkpoints
         in the exp-dir with filenames `checkpoint-xxx.pt`.
@@ -261,7 +305,7 @@ def get_parser():
         in which each floating-point parameter is the average of all the
         parameters from the start of training. Each time we take the average,
         we do: `model_avg = model * (average_period / batch_idx_train) +
-        model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
         """,
     )
 
@@ -470,7 +514,7 @@ def load_checkpoint_if_available(
 
 def save_checkpoint(
     params: AttributeDict,
-    model: Union[nn.Module, DDP],
+    model: nn.Module,
     model_avg: Optional[nn.Module] = None,
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
@@ -520,14 +564,15 @@ def save_checkpoint(
 
 def compute_loss(
     params: AttributeDict,
-    model: Union[nn.Module, DDP],
+    model: nn.Module,
     sp: spm.SentencePieceProcessor,
+    pl: UniqLexicon,
     batch: dict,
     is_training: bool,
     warmup: float = 1.0,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
-    Compute transducer loss given the model and its inputs.
+    Compute CTC loss given the model and its inputs.
 
     Args:
       params:
@@ -554,8 +599,12 @@ def compute_loss(
     feature_lens = supervisions["num_frames"].to(device)
 
     texts = batch["supervisions"]["text"]
-    y = sp.encode(texts, out_type=int)
-    y = k2.RaggedTensor(y).to(device)
+
+    if sp is not None:
+        y = sp.encode(texts, out_type=int)
+        y = k2.RaggedTensor(y).to(device)
+    else:
+        y = pl.texts_to_token_ids(texts).to(device)
 
     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss = model(
@@ -593,8 +642,9 @@ def compute_loss(
 
 def compute_validation_loss(
     params: AttributeDict,
-    model: Union[nn.Module, DDP],
+    model: nn.Module,
     sp: spm.SentencePieceProcessor,
+    pl: UniqLexicon,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
 ) -> MetricsTracker:
@@ -608,6 +658,7 @@ def compute_validation_loss(
             params=params,
             model=model,
             sp=sp,
+            pl=pl,
             batch=batch,
             is_training=False,
         )
@@ -627,10 +678,11 @@ def compute_validation_loss(
 
 def train_one_epoch(
     params: AttributeDict,
-    model: Union[nn.Module, DDP],
+    model: nn.Module,
     optimizer: torch.optim.Optimizer,
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
+    pl: UniqLexicon,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     scaler: GradScaler,
@@ -688,8 +740,8 @@ def train_one_epoch(
         loss, loss_info = compute_loss(
             params=params,
             model=model,
-            model_avg=model_avg,
             sp=sp,
+            pl=pl,
             batch=batch,
             is_training=True,
             warmup=(params.batch_idx_train / params.model_warm_step),
@@ -708,6 +760,17 @@ def train_one_epoch(
         if params.print_diagnostics and batch_idx == 30:
             return
 
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
         if (
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
@@ -757,6 +820,7 @@ def train_one_epoch(
                 params=params,
                 model=model,
                 sp=sp,
+                pl=pl,
                 valid_dl=valid_dl,
                 world_size=world_size,
             )
@@ -806,12 +870,21 @@ def run(rank, world_size, args):
     device = torch.device("cuda", rank)
     logging.info(f"Device: {device}")
 
-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    if params.lang_type == "bpe":
+        logging.info("Using BPE model")
+        sp = spm.SentencePieceProcessor()
+        sp.load(params.bpe_model)
 
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+        # <blk> is defined in local/train_bpe_model.py
+        params.blank_id = sp.piece_to_id("<blk>")
+        params.vocab_size = sp.get_piece_size()
+        pl = None
+    elif params.lang_type == "phone":
+        logging.info("Using phone lexicon")
+        pl = UniqLexicon(params.lang_dir)
+        params.blank_id = 0
+        params.vocab_size = max(pl.tokens) + 1
+        sp = None
 
     logging.info(params)
@@ -875,12 +948,13 @@ def run(rank, world_size, args):
     valid_cuts = gigaspeech.dev_cuts()
     valid_dl = gigaspeech.valid_dataloaders(valid_cuts)
 
-    if not params.print_diagnostics:
+    if 0 and not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=train_dl,
             optimizer=optimizer,
             sp=sp,
+            pl=pl,
             params=params,
         )
@@ -906,6 +980,7 @@ def run(rank, world_size, args):
             params=params,
             optimizer=optimizer,
             scheduler=scheduler,
            sp=sp,
+            pl=pl,
             train_dl=train_dl,
             valid_dl=valid_dl,
             scaler=scaler,
@@ -937,10 +1012,11 @@ def run(rank, world_size, args):
 
 def scan_pessimistic_batches_for_oom(
-    model: Union[nn.Module, DDP],
+    model: nn.Module,
     train_dl: torch.utils.data.DataLoader,
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
+    pl: UniqLexicon,
     params: AttributeDict,
 ):
     from lhotse.dataset import find_pessimistic_batches
 
@@ -960,6 +1036,7 @@ def scan_pessimistic_batches_for_oom(
             params=params,
             model=model,
             sp=sp,
+            pl=pl,
             batch=batch,
             is_training=True,
             warmup=0.0,

From c6b4159dccb015a246a1be79cfce8748365e0e54 Mon Sep 17 00:00:00 2001
From: yfy62
Date: Wed, 26 Apr 2023 17:39:07 +0800
Subject: [PATCH 2/6] Update type hint

---
 egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py
index 2ce785ef8..09a3b6a1b 100755
--- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py
@@ -514,7 +514,7 @@ def load_checkpoint_if_available(
 
 def save_checkpoint(
     params: AttributeDict,
-    model: nn.Module,
+    model: Union[nn.Module, DDP],
     model_avg: Optional[nn.Module] = None,
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
@@ -642,7 +642,7 @@ def compute_loss(
 
 def compute_validation_loss(
     params: AttributeDict,
-    model: nn.Module,
+    model: Union[nn.Module, DDP],
     sp: spm.SentencePieceProcessor,
     pl: UniqLexicon,
     valid_dl: torch.utils.data.DataLoader,
@@ -1012,7 +1012,7 @@ def run(rank, world_size, args):
 
 def scan_pessimistic_batches_for_oom(
-    model: nn.Module,
+    model: Union[nn.Module, DDP],
     train_dl: torch.utils.data.DataLoader,
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,

From 0b5996bd3d0fcd5bb4698c7c4bfd676f52883284 Mon Sep 17 00:00:00 2001
From: yfy62
Date: Wed, 26 Apr 2023 17:40:10 +0800
Subject: [PATCH 3/6] Fix comments

---
 egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py
index 09a3b6a1b..6f9a85c2d 100755
--- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py
@@ -572,7 +572,7 @@ def compute_loss(
     warmup: float = 1.0,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
-    Compute CTC loss given the model and its inputs.
+    Compute transducer loss given the model and its inputs.
 
     Args:
       params:
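
[Illustrative note, not part of the patch series] The phone path added in
PATCH 1/6 swaps SentencePiece encoding for a lexicon lookup. The Python sketch
below mirrors the UniqLexicon calls that train.py and decode.py now make; it
assumes an icefall checkout in which data/lang_phone has been prepared by the
recipe, and is a sketch of the intended flow rather than code from this
series.

    import k2
    from icefall.lexicon import UniqLexicon

    # Load the phone lexicon (one pronunciation per word).
    pl = UniqLexicon("data/lang_phone")

    # One row of phone IDs per utterance; compute_loss() feeds this ragged
    # tensor to the pruned transducer loss in place of BPE token IDs.
    y = pl.texts_to_token_ids(["HELLO WORLD"])
    assert isinstance(y, k2.RaggedTensor)

    # Blank is ID 0, so the output vocabulary must cover the whole phone
    # inventory; this mirrors `params.vocab_size = max(pl.tokens) + 1` above.
    vocab_size = max(pl.tokens) + 1
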
 
     Args:
       params:

From 23a9b662950c6ead010ea45ba07515a4d30857d4 Mon Sep 17 00:00:00 2001
From: yfy62
Date: Wed, 26 Apr 2023 17:43:11 +0800
Subject: [PATCH 4/6] Fix type hint

---
 egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py
index 6f9a85c2d..32998a8de 100755
--- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py
@@ -678,7 +678,7 @@ def compute_validation_loss(
 
 def train_one_epoch(
     params: AttributeDict,
-    model: nn.Module,
+    model: Union[nn.Module, DDP],
     optimizer: torch.optim.Optimizer,
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,

From 5e7e1a350eaa0ec73efabd918bcaedce36b44e5e Mon Sep 17 00:00:00 2001
From: yfy62
Date: Wed, 26 Apr 2023 17:45:00 +0800
Subject: [PATCH 5/6] Fix type hint and re-enable OOM scan

---
 egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py
index 32998a8de..8edbf8500 100755
--- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py
@@ -564,7 +564,7 @@ def save_checkpoint(
 
 def compute_loss(
     params: AttributeDict,
-    model: nn.Module,
+    model: Union[nn.Module, DDP],
     sp: spm.SentencePieceProcessor,
     pl: UniqLexicon,
     batch: dict,
@@ -948,7 +948,7 @@ def run(rank, world_size, args):
     valid_cuts = gigaspeech.dev_cuts()
     valid_dl = gigaspeech.valid_dataloaders(valid_cuts)
 
-    if 0 and not params.print_diagnostics:
+    if not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(

From ce33bf432c36b277b3e86ed03c62f4e9e979b525 Mon Sep 17 00:00:00 2001
From: Yifan Yang
Date: Thu, 27 Apr 2023 14:44:42 +0800
Subject: [PATCH 6/6] Fix for style check

---
 .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 1b1e1e5bb..5c01d7190 100644
--- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -394,9 +394,7 @@ class GigaSpeechAsrDataModule:
     @lru_cache()
     def dev_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        cuts_valid = load_manifest_lazy(
-            self.args.manifest_dir / "cuts_DEV.jsonl.gz"
-        )
+        cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
         if self.args.small_dev:
             return cuts_valid.subset(first=1000)
         else:
@@ -405,6 +403,4 @@ class GigaSpeechAsrDataModule:
     @lru_cache()
     def test_cuts(self) -> CutSet:
         logging.info("About to get test cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_TEST.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz")
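
[Illustrative note, not part of the patch series] None of the usage examples
in decode.py exercises the new phone path. A plausible invocation for phone
error rate (PER) scoring, inferred from the flags added in PATCH 1/6 (the
epoch/avg values and paths are placeholders in the style of the existing
examples):

./pruned_transducer_stateless2/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./pruned_transducer_stateless2/exp \
    --lang-dir data/lang_phone \
    --max-duration 600 \
    --decoding-method greedy_search \
    --metrics PER

With --bpe-model left at its new default of None, main() takes the phone
branch, greedy_search together with --metrics PER satisfies the assertion
there, and save_results() writes per-summary-*.txt instead of
wer-summary-*.txt. For word-level scoring with a phone model, one would
instead pass --decoding-method fast_beam_search_LG (or
fast_beam_search_nbest_LG) with --metrics WER, since those are the methods
that map phones to words through the LG graph.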