From ca8ed842f760bf52c5bab5685b549d5d5445bff8 Mon Sep 17 00:00:00 2001
From: pkufool
Date: Wed, 22 Mar 2023 19:45:45 +0800
Subject: [PATCH] Add context biasing for wenetspeech

---
 .../pruned_transducer_stateless4/decode.py |   9 +-
 .../ASR/tdnn_lstm_ctc/asr_datamodule.py    |   7 +
 .../asr_datamodule.py                      |  38 ++--
 .../pruned_transducer_stateless5/decode.py | 189 +++++++++++-------
 icefall/context_graph.py                   |  51 ++++-
 5 files changed, 198 insertions(+), 96 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
index eb22daefe..afd3a9e0e 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
@@ -913,7 +913,7 @@ def main():
             for line in open(params.context_file).readlines():
                 contexts.append(line.strip())
             context_graph = ContextGraph(params.context_score)
-            context_graph.build_context_graph(contexts, sp)
+            context_graph.build_context_graph_bpe(contexts, sp)
         else:
             context_graph = None
     else:
@@ -935,8 +935,11 @@ def main():
     test_book_cuts = librispeech.test_book_cuts()
     test_book_dl = librispeech.test_dataloaders(test_book_cuts)
 
-    test_sets = ["test-book", "test-clean", "test-other"]
-    test_dl = [test_book_dl, test_clean_dl, test_other_dl]
+    test_book2_cuts = librispeech.test_book2_cuts()
+    test_book2_dl = librispeech.test_dataloaders(test_book2_cuts)
+
+    test_sets = ["test-book", "test-book2", "test-clean", "test-other"]
+    test_dl = [test_book_dl, test_book2_dl, test_clean_dl, test_other_dl]
 
     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index ac90daafe..41698d00a 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -452,6 +452,13 @@ class LibriSpeechAsrDataModule:
             self.args.manifest_dir / "libri_books_feats.jsonl.gz"
         )
 
+    @lru_cache()
+    def test_book2_cuts(self) -> CutSet:
+        logging.info("About to get test-book2 cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libri_books2_feats.jsonl.gz"
+        )
+
     @lru_cache()
     def test_other_cuts(self) -> CutSet:
         logging.info("About to get test-other cuts")
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 9c07263a2..fc8039fe3 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -46,8 +46,8 @@ from torch.utils.data import DataLoader
 
 from icefall.utils import str2bool
 
-set_caching_enabled(False)
-torch.set_num_threads(1)
+# set_caching_enabled(False)
+# torch.set_num_threads(1)
 
 
 class _SeedWorkers:
@@ -109,7 +109,7 @@ class WenetSpeechAsrDataModule:
         group.add_argument(
             "--num-buckets",
             type=int,
-            default=300,
+            default=30,
             help="The number of buckets for the DynamicBucketingSampler"
             "(you might want to increase it for larger datasets).",
         )
@@ -373,7 +373,7 @@ class WenetSpeechAsrDataModule:
         return valid_dl
 
     def test_dataloaders(self, cuts: CutSet) -> DataLoader:
-        logging.debug("About to create test dataset")
+        logging.info("About to create test dataset")
         test = K2SpeechRecognitionDataset(
             input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
             if self.args.on_the_fly_feats
@@ -383,19 +383,22 @@ class WenetSpeechAsrDataModule:
         sampler = DynamicBucketingSampler(
             cuts,
             max_duration=self.args.max_duration,
-            rank=0,
-            world_size=1,
+            buffer_size=10000,
+            # rank=0,
+            # world_size=1,
             shuffle=False,
         )
-        from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
+        # from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
 
-        test_iter_dataset = IterableDatasetWrapper(
-            dataset=test,
-            sampler=sampler,
-        )
+        # test_iter_dataset = IterableDatasetWrapper(
+        #     dataset=test,
+        #     sampler=sampler,
+        # )
         test_dl = DataLoader(
-            test_iter_dataset,
+            # test_iter_dataset,
+            test,
             batch_size=None,
+            sampler=sampler,
             num_workers=self.args.num_workers,
         )
         return test_dl
@@ -411,14 +414,19 @@ class WenetSpeechAsrDataModule:
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV2.jsonl.gz")
 
     @lru_cache()
     def test_net_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_NET cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET2.jsonl.gz")
 
     @lru_cache()
     def test_meeting_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_MEETING cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING2.jsonl.gz")
+
+    @lru_cache()
+    def test_car_cuts(self) -> List[CutSet]:
+        logging.info("About to get TEST_CAR cuts")
+        return load_manifest_lazy(self.args.manifest_dir / "car_test_feats.jsonl.gz")
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
index 7bd1177bd..6a036a709 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
@@ -95,8 +95,10 @@ When training with the L subset, the streaming usage:
 """
 
 import argparse
+import glob
 import logging
 import math
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -114,6 +116,7 @@ from beam_search import (
 )
 from train import add_model_arguments, get_params, get_transducer_model
 
+from icefall import ContextGraph
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -277,6 +280,20 @@
         help="left context can be seen during decoding (in frames after subsampling)",
     )
 
+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=2,
+        help="The bonus score for each matched token in the context graph.",
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="Path to a file containing the biasing phrases, one per line.",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -288,6 +305,7 @@ def decode_one_batch(
     lexicon: Lexicon,
     batch: dict,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict.
     The dict has the following format:
@@ -325,14 +343,13 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    feature_lens += params.left_context
-    feature = torch.nn.functional.pad(
-        feature,
-        pad=(0, 0, 0, params.left_context),
-        value=LOG_EPS,
-    )
-
     if params.simulate_streaming:
+        feature_lens += params.left_context
+        feature = torch.nn.functional.pad(
+            feature,
+            pad=(0, 0, 0, params.left_context),
+            value=LOG_EPS,
+        )
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
@@ -371,6 +388,7 @@
             encoder_out=encoder_out,
             beam=params.beam_size,
             encoder_out_lens=encoder_out_lens,
+            context_graph=context_graph,
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
@@ -410,7 +428,7 @@
             ): hyps
         }
     else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        return {f"beam_size_{params.beam_size}_context_score_{params.context_score}": hyps}
 
 
 def decode_dataset(
@@ -419,6 +437,7 @@
     model: nn.Module,
     lexicon: Lexicon,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
 
@@ -463,6 +482,7 @@
             lexicon=lexicon,
             decoding_graph=decoding_graph,
             batch=batch,
+            context_graph=context_graph,
         )
 
         for name, hyps in hyps_dict.items():
@@ -551,6 +571,7 @@
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam_size}"
+        params.suffix += f"-context-score-{params.context_score}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -664,13 +685,23 @@
     else:
         decoding_graph = None
 
+    if params.decoding_method == "modified_beam_search":
+        if os.path.exists(params.context_file):
+            contexts = []
+            for line in open(params.context_file).readlines():
+                contexts.append(line.strip())
+            context_graph = ContextGraph(params.context_score)
+            context_graph.build_context_graph_char(contexts, lexicon.token_table)
+        else:
+            context_graph = None
+    else:
+        context_graph = None
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
     # Note: Please use "pip install webdataset==0.1.103"
     # for installing the webdataset.
-    import glob
-    import os
 
     from lhotse import CutSet
     from lhotse.dataset.webdataset import export_to_webdataset
 
@@ -679,82 +710,98 @@
     args.return_cuts = True
     wenetspeech = WenetSpeechAsrDataModule(args)
 
-    dev = "dev"
-    test_net = "test_net"
-    test_meeting = "test_meeting"
+    # dev = "dev"
+    # test_net = "test_net"
+    # test_meeting = "test_meeting"
 
-    if not os.path.exists(f"{dev}/shared-0.tar"):
-        os.makedirs(dev)
-        dev_cuts = wenetspeech.valid_cuts()
-        export_to_webdataset(
-            dev_cuts,
-            output_path=f"{dev}/shared-%d.tar",
-            shard_size=300,
-        )
+    # if not os.path.exists(f"{dev}/shared-0.tar"):
+    #     os.makedirs(dev)
+    #     dev_cuts = wenetspeech.valid_cuts()
+    #     export_to_webdataset(
+    #         dev_cuts,
+    #         output_path=f"{dev}/shared-%d.tar",
+    #         shard_size=300,
+    #     )
 
-    if not os.path.exists(f"{test_net}/shared-0.tar"):
-        os.makedirs(test_net)
-        test_net_cuts = wenetspeech.test_net_cuts()
-        export_to_webdataset(
-            test_net_cuts,
-            output_path=f"{test_net}/shared-%d.tar",
-            shard_size=300,
-        )
+    # if not os.path.exists(f"{test_net}/shared-0.tar"):
+    #     os.makedirs(test_net)
+    #     test_net_cuts = wenetspeech.test_net_cuts()
+    #     export_to_webdataset(
+    #         test_net_cuts,
+    #         output_path=f"{test_net}/shared-%d.tar",
+    #         shard_size=300,
+    #     )
 
-    if not os.path.exists(f"{test_meeting}/shared-0.tar"):
-        os.makedirs(test_meeting)
-        test_meeting_cuts = wenetspeech.test_meeting_cuts()
-        export_to_webdataset(
-            test_meeting_cuts,
-            output_path=f"{test_meeting}/shared-%d.tar",
-            shard_size=300,
-        )
+    # if not os.path.exists(f"{test_meeting}/shared-0.tar"):
+    #     os.makedirs(test_meeting)
+    #     test_meeting_cuts = wenetspeech.test_meeting_cuts()
+    #     export_to_webdataset(
+    #         test_meeting_cuts,
+    #         output_path=f"{test_meeting}/shared-%d.tar",
+    #         shard_size=300,
+    #     )
 
-    dev_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
-    ]
-    cuts_dev_webdataset = CutSet.from_webdataset(
-        dev_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
+    # dev_shards = [
+    #     str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+    # ]
+    # cuts_dev_webdataset = CutSet.from_webdataset(
+    #     dev_shards,
+    #     split_by_worker=True,
+    #     split_by_node=True,
+    #     shuffle_shards=True,
+    # )
 
-    test_net_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
-    ]
-    cuts_test_net_webdataset = CutSet.from_webdataset(
-        test_net_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
+    # test_net_shards = [
+    #     str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
+    # ]
+    # cuts_test_net_webdataset = CutSet.from_webdataset(
+    #     test_net_shards,
+    #     split_by_worker=True,
+    #     split_by_node=True,
+    #     shuffle_shards=True,
+    # )
 
-    test_meeting_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
-    ]
-    cuts_test_meeting_webdataset = CutSet.from_webdataset(
-        test_meeting_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
+    # test_meeting_shards = [
+    #     str(path)
+    #     for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
+    # ]
+    # cuts_test_meeting_webdataset = CutSet.from_webdataset(
+    #     test_meeting_shards,
+    #     split_by_worker=True,
+    #     split_by_node=True,
+    #     shuffle_shards=True,
+    # )
 
-    dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset)
-    test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
-    test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meeting_webdataset)
+    # dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset)
+    # test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
+    # test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meeting_webdataset)
 
-    test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
-    test_dl = [dev_dl, test_net_dl, test_meeting_dl]
+    dev_cuts = wenetspeech.valid_cuts()
+    dev_dl = wenetspeech.valid_dataloaders(dev_cuts)
 
-    for test_set, test_dl in zip(test_sets, test_dl):
+    test_net_cuts = wenetspeech.test_net_cuts()
+    test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
+
+    test_meeting_cuts = wenetspeech.test_meeting_cuts()
+    test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
+
+    test_car_cuts = wenetspeech.test_car_cuts()
+    test_car_dl = wenetspeech.test_dataloaders(test_car_cuts)
+
+    # test_sets = ["CAR", "TEST_NET", "DEV", "TEST_MEETING"]
+    # test_dls = [test_car_dl, test_net_dl, dev_dl, test_meeting_dl]
+
+    test_sets = ["CAR", "TEST_NET", "TEST_MEETING"]
+    test_dls = [test_car_dl, test_net_dl, test_meeting_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dls):
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,
             model=model,
             lexicon=lexicon,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
         )
         save_results(
             params=params,
diff --git a/icefall/context_graph.py b/icefall/context_graph.py
index 76e7808ad..9f4a26891 100644
--- a/icefall/context_graph.py
+++ b/icefall/context_graph.py
@@ -15,9 +15,12 @@
 # limitations under the License.
 
 
+import logging
+import re
 from dataclasses import dataclass
 from typing import List
 import argparse
+import k2
 import kaldifst
 import sentencepiece as spm
 
@@ -34,9 +37,45 @@ class ContextGraph:
     def __init__(self, context_score: float = 1):
         self.context_score = context_score
 
-    def build_context_graph(self, contexts: List[str], sp: spm.SentencePieceProcessor):
+    def build_context_graph_char(self, contexts: List[str], token_table: k2.SymbolTable):
+        """Build the context graph from a list of context biasing phrases.
+        Args:
+          contexts:
+            It is a list of strings, one biasing phrase per entry.
+            An example containing two strings is given below:
+
+            ['你好中国', '北京欢迎您']
+          token_table:
+            The SymbolTable containing tokens and corresponding ids.
+
+        Note:
+          Phrases containing an OOV character are skipped.
+        """
+        ids: List[List[int]] = []
+        whitespace = re.compile(r"([ \t])")
+        for text in contexts:
+            text = re.sub(whitespace, "", text)
+            sub_ids: List[int] = []
+            skip = False
+            for char in text:
+                if char not in token_table:
+                    skip = True
+                    break
+                sub_ids.append(token_table[char])
+            if skip:
+                logging.warning(f"Skipping context {text}, as it contains an OOV character.")
+                continue
+            ids.append(sub_ids)
+        self.build_context_graph(ids)
+
+    def build_context_graph_bpe(self, contexts: List[str], sp: spm.SentencePieceProcessor):
+        """Build the context graph from phrases encoded into BPE token IDs."""
         contexts_bpe = sp.encode(contexts)
+        self.build_context_graph(contexts_bpe)
+
+    def build_context_graph(self, token_ids: List[List[int]]):
+        """Build the context graph (an FST) from a list of token-ID sequences."""
         graph = kaldifst.StdVectorFst()
         start_state = (
             graph.add_state()
         )  # 1st state will be state 0 (returned by add_state)
         graph.start = 0  # set the start state to 0
         graph.set_final(start_state, weight=0)  # weight is in log space
 
-        for bpe_ids in contexts_bpe:
+        for tokens in token_ids:
             prev_state = start_state
             next_state = start_state
             backoff_score = 0
-            for i in range(len(bpe_ids)):
+            for i in range(len(tokens)):
                 score = self.context_score
-                next_state = graph.add_state() if i < len(bpe_ids) - 1 else start_state
+                next_state = graph.add_state() if i < len(tokens) - 1 else start_state
                 graph.add_arc(
                     state=prev_state,
                     arc=kaldifst.StdArc(
-                        ilabel=bpe_ids[i],
-                        olabel=bpe_ids[i],
+                        ilabel=tokens[i],
+                        olabel=tokens[i],
                         weight=score,
                         nextstate=next_state,
                     ),
@@ -105,7 +142,7 @@ if __name__ == "__main__":
 
     contexts = ["LOVE CHINA", "HELLO WORLD", "LOVE WORLD"]
     context_graph = ContextGraph()
-    context_graph.build_context_graph(contexts, sp)
+    context_graph.build_context_graph_bpe(contexts, sp)
 
     if not is_module_available("graphviz"):
         raise ValueError("Please 'pip install graphviz' first.")
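
Note on the scoring semantics: the hunks above show only the matching arcs of build_context_graph; the failure (backoff) arcs that use backoff_score fall outside the diff context. The sketch below is a minimal, self-contained illustration of the behaviour the graph is intended to have during beam search, under that assumption; it is not the icefall implementation, and the names TrieNode, build_trie and score_token are hypothetical. Each token extending a partial match of a biasing phrase earns +context_score; a fully matched phrase keeps its accumulated bonus and returns to the start state (the last arc of each phrase points back to start_state above); a broken partial match pays its bonus back. Phrases that are strict prefixes of other phrases are simplified here, whereas the FST builder gives every phrase its own chain of arcs.

from typing import Dict, List, Tuple


class TrieNode:
    def __init__(self) -> None:
        self.children: Dict[int, "TrieNode"] = {}
        self.depth = 0  # number of tokens matched to reach this node


def build_trie(phrases: List[List[int]]) -> TrieNode:
    # Build a token trie over the biasing phrases (token-ID sequences),
    # analogous to the chains of arcs grown from start_state above.
    root = TrieNode()
    for tokens in phrases:
        node = root
        for t in tokens:
            if t not in node.children:
                child = TrieNode()
                child.depth = node.depth + 1
                node.children[t] = child
            node = node.children[t]
    return root


def score_token(
    node: TrieNode, root: TrieNode, token: int, context_score: float
) -> Tuple[TrieNode, float]:
    # Consume one decoded token; return (next_state, score_delta).
    if token in node.children:
        nxt = node.children[token]
        if not nxt.children:
            # Full phrase matched: keep the bonus, go back to the start state.
            return root, context_score
        return nxt, context_score  # partial match extended
    # Match broken: cancel the bonus accumulated along the partial match
    # (the role of the negative-weight backoff arc), then try to restart.
    delta = -node.depth * context_score
    if token in root.children:
        return root.children[token], delta + context_score
    return root, delta


if __name__ == "__main__":
    # Bias towards the token sequence [5, 7] with a per-token bonus of 2.0.
    root = build_trie([[5, 7]])
    state, d1 = score_token(root, root, 5, 2.0)   # +2.0: partial match started
    state, d2 = score_token(state, root, 9, 2.0)  # -2.0: match broken, bonus cancelled
    assert d1 == 2.0 and d2 == -2.0

A per-token delta of this kind would be added to a hypothesis's log-probability each time it emits a token, which is presumably what passing context_graph=context_graph into modified_beam_search enables.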