diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7/decode.py
index af54af8da..be58c4e43 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7/decode.py
@@ -58,6 +58,7 @@ Usage:
 
 import argparse
 import logging
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -76,6 +77,8 @@ from beam_search import (
 )
 from train import add_model_arguments, get_params, get_transducer_model
 
+from icefall import ContextGraph, LmScorer, NgramLm
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -211,6 +214,26 @@ def get_parser():
         Used only when --decoding_method is greedy_search""",
     )
 
+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=2,
+        help="""
+        The bonus score for each token of the context biasing words/phrases.
+        Used only when --decoding_method is modified_beam_search.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="""
+        The path of the context biasing list, one word/phrase per line.
+        Used only when --decoding_method is modified_beam_search.
+        """,
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -222,6 +245,7 @@ def decode_one_batch(
     token_table: k2.SymbolTable,
     batch: dict,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -285,6 +309,7 @@ def decode_one_batch(
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
+            context_graph=context_graph,
         )
     else:
         hyp_tokens = []
@@ -324,7 +349,12 @@ def decode_one_batch(
             ): hyps
         }
     else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        key = f"beam_size_{params.beam_size}"
+        if params.has_contexts:
+            key += f"-context-score-{params.context_score}"
+        else:
+            key += "-no-context-words"
+        return {key: hyps}
 
 
 def decode_dataset(
@@ -333,6 +363,7 @@ def decode_dataset(
     model: nn.Module,
     token_table: k2.SymbolTable,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
 
@@ -377,6 +408,7 @@ def decode_dataset(
             model=model,
             token_table=token_table,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
             batch=batch,
         )
 
@@ -407,16 +439,17 @@ def save_results(
     for key, results in results_dict.items():
         recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
+        # we compute CER for aishell dataset.
+        results_char = []
+        for res in results:
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+
+        store_transcripts(filename=recog_path, texts=results_char)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        # we compute CER for aishell dataset.
-        results_char = []
-        for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
             )
@@ -457,6 +490,12 @@ def main():
         "fast_beam_search",
         "modified_beam_search",
     )
+
+    if os.path.exists(params.context_file):
+        params.has_contexts = True
+    else:
+        params.has_contexts = False
+
     params.res_dir = params.exp_dir / params.decoding_method
 
     if params.iter > 0:
@@ -470,6 +509,10 @@ def main():
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        if params.has_contexts:
+            params.suffix += f"-context-score-{params.context_score}"
+        else:
+            params.suffix += "-no-context-words"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -490,6 +533,11 @@ def main():
     params.blank_id = 0
     params.vocab_size = max(lexicon.tokens) + 1
 
+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )
+
     logging.info(params)
 
     logging.info("About to create model")
@@ -586,6 +634,19 @@ def main():
     else:
         decoding_graph = None
 
+    if params.decoding_method == "modified_beam_search":
+        if os.path.exists(params.context_file):
+            contexts_text = []
+            for line in open(params.context_file).readlines():
+                contexts_text.append(line.strip())
+            contexts = graph_compiler.texts_to_ids(contexts_text)
+            context_graph = ContextGraph(params.context_score)
+            context_graph.build(contexts)
+        else:
+            context_graph = None
+    else:
+        context_graph = None
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
@@ -608,6 +669,7 @@ def main():
             model=model,
             token_table=lexicon.token_table,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
         )
 
         save_results(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
index 1ea134c12..25b79d600 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import logging
 import warnings
 from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Tuple
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
index 6aacd7f92..524366068 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
@@ -131,8 +131,6 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
 import k2
-import kaldifst
-import graphviz
 import sentencepiece as spm
 import torch
 import torch.nn as nn
@@ -576,7 +574,10 @@ def decode_one_batch(
         return {key: (hyps, timestamps)}
     else:
         key = f"beam_size_{params.beam_size}"
-        key += f"-context-score-{params.context_score}"
+        if params.has_contexts:
+            key += f"-context-score-{params.context_score}"
+        else:
+            key += "-no-context-words"
         return {key: (hyps, timestamps)}
 
 
@@ -626,7 +627,7 @@ def decode_dataset(
     if params.decoding_method == "greedy_search":
         log_interval = 50
     else:
-        log_interval = 1
+        log_interval = 20
 
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@@ -759,6 +760,12 @@ def main():
         "fast_beam_search_nbest_oracle",
         "modified_beam_search",
     )
+
+    if os.path.exists(params.context_file):
+        params.has_contexts = True
+    else:
+        params.has_contexts = False
+
     params.res_dir = params.exp_dir / params.decoding_method
 
     if params.iter > 0:
@@ -781,7 +788,10 @@ def main():
         params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        params.suffix += f"-context-score-{params.context_score}"
+        if params.has_contexts:
+            params.suffix += f"-context-score-{params.context_score}"
+        else:
+            params.suffix += "-no-context-words"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -938,14 +948,8 @@ def main():
     test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
     test_other_dl = librispeech.test_dataloaders(test_other_cuts)
 
-    test_book_cuts = librispeech.test_book_cuts()
-    test_book_dl = librispeech.test_dataloaders(test_book_cuts)
-
-    test_book2_cuts = librispeech.test_book2_cuts()
-    test_book2_dl = librispeech.test_dataloaders(test_book2_cuts)
-
-    test_sets = ["test-book", "test-book2", "test-clean", "test-other"]
-    test_dl = [test_book_dl, test_book2_dl, test_clean_dl, test_other_dl]
+    test_sets = ["test-clean", "test-other"]
+    test_dl = [test_clean_dl, test_other_dl]
 
     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index c21c39322..c47964b07 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -389,7 +389,6 @@ class LibriSpeechAsrDataModule:
         )
         sampler = DynamicBucketingSampler(
             cuts,
-            num_buckets=2,
             max_duration=self.args.max_duration,
             shuffle=False,
         )
@@ -468,25 +467,6 @@ class LibriSpeechAsrDataModule:
             self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
         )
 
-    @lru_cache()
-    def test_book_cuts(self) -> CutSet:
-        logging.info("About to get test-books cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "libri_books_feats.jsonl.gz")
-
-    @lru_cache()
-    def test_book_test_cuts(self) -> CutSet:
-        logging.info("About to get test-books cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "libri_book_test_feats.jsonl.gz"
-        )
-
-    @lru_cache()
-    def test_book2_cuts(self) -> CutSet:
-        logging.info("About to get test-books2 cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "libri_books2_feats.jsonl.gz"
-        )
-
     @lru_cache()
     def test_other_cuts(self) -> CutSet:
         logging.info("About to get test-other cuts")
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 8c18e94f9..7cb2e1048 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -396,21 +396,14 @@ class WenetSpeechAsrDataModule:
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV2.jsonl.gz")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
 
     @lru_cache()
     def test_net_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_NET cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET2.jsonl.gz")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz")
 
     @lru_cache()
     def test_meeting_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_MEETING cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_TEST_MEETING2.jsonl.gz"
-        )
-
-    @lru_cache()
-    def test_car_cuts(self) -> List[CutSet]:
-        logging.info("About to get TEST_CAR cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "car_test_feats.jsonl.gz")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz")
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
index 7d0f987bd..dc431578c 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
@@ -533,9 +533,12 @@ def decode_one_batch(
             ): hyps
         }
     else:
-        return {
-            f"beam_size_{params.beam_size}_context_score_{params.context_score}": hyps
-        }
+        key = f"beam_size_{params.beam_size}"
+        if params.has_contexts:
+            key += f"-context-score-{params.context_score}"
+        else:
+            key += "-no-context-words"
+        return {key: hyps}
 
 
 def decode_dataset(
@@ -674,6 +677,12 @@ def main():
         "modified_beam_search_lm_shallow_fusion",
         "modified_beam_search_LODR",
     )
+
+    if os.path.exists(params.context_file):
+        params.has_contexts = True
+    else:
+        params.has_contexts = False
+
     params.res_dir = params.exp_dir / params.decoding_method
 
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
@@ -683,7 +692,10 @@ def main():
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam_size}"
-        params.suffix += f"-context-score-{params.context_score}"
+        if params.has_contexts:
+            params.suffix += f"-context-score-{params.context_score}"
+        else:
+            params.suffix += "-no-context-words"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -851,14 +863,10 @@ def main():
     if params.decoding_method == "modified_beam_search":
         if os.path.exists(params.context_file):
-            contexts = []
+            contexts_text = []
             for line in open(params.context_file).readlines():
-                context_list = graph_compiler.texts_to_ids(line.strip())
-                tmp = []
-                for context in context_list:
-                    for x in context:
-                        tmp.append(x)
-                contexts.append(tmp)
+                contexts_text.append(line.strip())
+            contexts = graph_compiler.texts_to_ids(contexts_text)
             context_graph = ContextGraph(params.context_score)
             context_graph.build(contexts)
         else:
             context_graph = None
@@ -882,11 +890,8 @@ def main():
     test_meeting_cuts = wenetspeech.test_meeting_cuts()
     test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
 
-    test_car_cuts = wenetspeech.test_car_cuts()
-    test_car_dl = wenetspeech.test_dataloaders(test_car_cuts)
-
-    test_sets = ["CAR", "TEST_NET", "DEV", "TEST_MEETING"]
-    test_dls = [test_car_dl, test_net_dl, dev_dl, test_meeting_dl]
+    test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
+    test_dls = [dev_dl, test_net_dl, test_meeting_dl]
 
     for test_set, test_dl in zip(test_sets, test_dls):
         results_dict = decode_dataset(
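
For reference, here is a minimal standalone sketch (not part of the patch) of the context-graph construction that this diff adds to the decode scripts. The `data/lang_char` and `context.txt` paths are illustrative assumptions, not files introduced by the patch:

```python
import os

import torch

from icefall import ContextGraph
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.lexicon import Lexicon

# Illustrative paths: any aishell-style lang dir and biasing list would do.
lexicon = Lexicon("data/lang_char")
graph_compiler = CharCtcTrainingGraphCompiler(
    lexicon=lexicon,
    device=torch.device("cpu"),
)

context_file = "context.txt"  # one biasing word/phrase per line
if os.path.exists(context_file):
    # Map each word/phrase to token ids, then compile them into the graph
    # that modified_beam_search scores against during decoding.
    contexts_text = [line.strip() for line in open(context_file)]
    contexts = graph_compiler.texts_to_ids(contexts_text)
    context_graph = ContextGraph(2.0)  # matches the --context-score default
    context_graph.build(contexts)
else:
    context_graph = None  # decoding proceeds without biasing
```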