Add context biasing for wenetspeech

pkufool 2023-03-22 19:45:45 +08:00
parent 1f1e28c1ad
commit ca8ed842f7
5 changed files with 198 additions and 96 deletions

File 1 of 5

@@ -913,7 +913,7 @@ def main():
             for line in open(params.context_file).readlines():
                 contexts.append(line.strip())
             context_graph = ContextGraph(params.context_score)
-            context_graph.build_context_graph(contexts, sp)
+            context_graph.build_context_graph_bpe(contexts, sp)
         else:
             context_graph = None
     else:
@@ -935,8 +935,11 @@ def main():
     test_book_cuts = librispeech.test_book_cuts()
     test_book_dl = librispeech.test_dataloaders(test_book_cuts)
 
-    test_sets = ["test-book", "test-clean", "test-other"]
-    test_dl = [test_book_dl, test_clean_dl, test_other_dl]
+    test_book2_cuts = librispeech.test_book2_cuts()
+    test_book2_dl = librispeech.test_dataloaders(test_book2_cuts)
+
+    test_sets = ["test-book", "test-book2", "test-clean", "test-other"]
+    test_dl = [test_book_dl, test_book2_dl, test_clean_dl, test_other_dl]
 
     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(
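
For orientation, here is a minimal sketch of how the renamed BPE entry point is driven end to end. The model path and context-file name are hypothetical placeholders; `ContextGraph` is imported from `icefall`, as in the WenetSpeech decode script below.

    import sentencepiece as spm

    from icefall import ContextGraph

    # Hypothetical paths; substitute the recipe's real BPE model and a
    # plain-text list of biasing phrases, one phrase per line.
    sp = spm.SentencePieceProcessor()
    sp.load("data/lang_bpe_500/bpe.model")

    with open("context_words.txt") as f:
        contexts = [line.strip() for line in f]

    context_graph = ContextGraph(context_score=2.0)
    context_graph.build_context_graph_bpe(contexts, sp)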

File 2 of 5

@@ -452,6 +452,13 @@ class LibriSpeechAsrDataModule:
             self.args.manifest_dir / "libri_books_feats.jsonl.gz"
         )
 
+    @lru_cache()
+    def test_book2_cuts(self) -> CutSet:
+        logging.info("About to get test-books2 cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libri_books2_feats.jsonl.gz"
+        )
+
     @lru_cache()
     def test_other_cuts(self) -> CutSet:
         logging.info("About to get test-other cuts")

File 3 of 5

@@ -46,8 +46,8 @@ from torch.utils.data import DataLoader
 
 from icefall.utils import str2bool
 
-set_caching_enabled(False)
-torch.set_num_threads(1)
+# set_caching_enabled(False)
+# torch.set_num_threads(1)
 
 
 class _SeedWorkers:
@@ -109,7 +109,7 @@ class WenetSpeechAsrDataModule:
         group.add_argument(
             "--num-buckets",
             type=int,
-            default=300,
+            default=30,
             help="The number of buckets for the DynamicBucketingSampler"
             "(you might want to increase it for larger datasets).",
         )
@@ -373,7 +373,7 @@ class WenetSpeechAsrDataModule:
         return valid_dl
 
     def test_dataloaders(self, cuts: CutSet) -> DataLoader:
-        logging.debug("About to create test dataset")
+        logging.info("About to create test dataset")
         test = K2SpeechRecognitionDataset(
             input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
             if self.args.on_the_fly_feats
@@ -383,19 +383,22 @@ class WenetSpeechAsrDataModule:
         sampler = DynamicBucketingSampler(
             cuts,
             max_duration=self.args.max_duration,
-            rank=0,
-            world_size=1,
+            buffer_size=10000,
+            # rank=0,
+            # world_size=1,
             shuffle=False,
         )
-        from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
-
-        test_iter_dataset = IterableDatasetWrapper(
-            dataset=test,
-            sampler=sampler,
-        )
+        # from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
+
+        # test_iter_dataset = IterableDatasetWrapper(
+        #     dataset=test,
+        #     sampler=sampler,
+        # )
         test_dl = DataLoader(
-            test_iter_dataset,
+            # test_iter_dataset,
+            test,
             batch_size=None,
+            sampler=sampler,
             num_workers=self.args.num_workers,
         )
         return test_dl
@@ -411,14 +414,19 @@ class WenetSpeechAsrDataModule:
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV2.jsonl.gz")
 
     @lru_cache()
     def test_net_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_NET cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET2.jsonl.gz")
 
     @lru_cache()
     def test_meeting_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_MEETING cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING2.jsonl.gz")
+
+    @lru_cache()
+    def test_car_cuts(self) -> List[CutSet]:
+        logging.info("About to get TEST_CAR cuts")
+        return load_manifest_lazy(self.args.manifest_dir / "car_test_feats.jsonl.gz")
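
The `test_dataloaders` hunk above swaps lhotse's `IterableDatasetWrapper` for the plain map-style pattern: `DynamicBucketingSampler` emits whole batches of cuts, `K2SpeechRecognitionDataset` maps each batch to tensors, and the sampler is handed straight to `DataLoader` with `batch_size=None`. A self-contained sketch of that pattern, with placeholder values for `max_duration` and `num_workers`:

    from lhotse import CutSet
    from lhotse.dataset import DynamicBucketingSampler, K2SpeechRecognitionDataset
    from torch.utils.data import DataLoader


    def make_test_dl(cuts: CutSet) -> DataLoader:
        # Map-style dataset: __getitem__ receives a whole batch of cuts
        # selected by the sampler.
        test = K2SpeechRecognitionDataset(return_cuts=True)
        sampler = DynamicBucketingSampler(
            cuts,
            max_duration=200,  # placeholder; the script uses args.max_duration
            buffer_size=10000,
            shuffle=False,
        )
        # batch_size=None: the sampler already yields complete batches.
        return DataLoader(test, batch_size=None, sampler=sampler, num_workers=2)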

File 4 of 5

@@ -95,8 +95,10 @@ When training with the L subset, the streaming usage:
 
 import argparse
+import glob
 import logging
 import math
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -114,6 +116,7 @@ from beam_search import (
 )
 from train import add_model_arguments, get_params, get_transducer_model
 
+from icefall import ContextGraph
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -277,6 +280,20 @@ def get_parser():
         help="left context can be seen during decoding (in frames after subsampling)",
     )
 
+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=2,
+        help="",
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -288,6 +305,7 @@ def decode_one_batch(
     lexicon: Lexicon,
     batch: dict,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -325,14 +343,13 @@
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    feature_lens += params.left_context
-    feature = torch.nn.functional.pad(
-        feature,
-        pad=(0, 0, 0, params.left_context),
-        value=LOG_EPS,
-    )
-
-    if params.simulate_streaming:
+    if params.simulate_streaming:
+        feature_lens += params.left_context
+        feature = torch.nn.functional.pad(
+            feature,
+            pad=(0, 0, 0, params.left_context),
+            value=LOG_EPS,
+        )
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
@@ -371,6 +388,7 @@
             encoder_out=encoder_out,
             beam=params.beam_size,
             encoder_out_lens=encoder_out_lens,
+            context_graph=context_graph
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
@@ -410,7 +428,7 @@
             ): hyps
         }
     else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        return {f"beam_size_{params.beam_size}_context_score_{params.context_score}": hyps}
 
 
 def decode_dataset(
@@ -419,6 +437,7 @@ def decode_dataset(
     model: nn.Module,
     lexicon: Lexicon,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
 
@@ -463,6 +482,7 @@
             lexicon=lexicon,
             decoding_graph=decoding_graph,
             batch=batch,
+            context_graph=context_graph,
         )
 
         for name, hyps in hyps_dict.items():
@@ -551,6 +571,7 @@ def main():
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam_size}"
+        params.suffix += f"-context-score-{params.context_score}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -664,13 +685,23 @@ def main():
     else:
         decoding_graph = None
 
+    if params.decoding_method == "modified_beam_search":
+        if os.path.exists(params.context_file):
+            contexts = []
+            for line in open(params.context_file).readlines():
+                contexts.append(line.strip())
+            context_graph = ContextGraph(params.context_score)
+            context_graph.build_context_graph_char(contexts, lexicon.token_table)
+        else:
+            context_graph = None
+    else:
+        context_graph = None
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
     # Note: Please use "pip install webdataset==0.1.103"
     # for installing the webdataset.
-    import glob
-    import os
-
     from lhotse import CutSet
     from lhotse.dataset.webdataset import export_to_webdataset
@@ -679,82 +710,98 @@ def main():
     args.return_cuts = True
     wenetspeech = WenetSpeechAsrDataModule(args)
 
-    dev = "dev"
-    test_net = "test_net"
-    test_meeting = "test_meeting"
-
-    if not os.path.exists(f"{dev}/shared-0.tar"):
-        os.makedirs(dev)
-        dev_cuts = wenetspeech.valid_cuts()
-        export_to_webdataset(
-            dev_cuts,
-            output_path=f"{dev}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    if not os.path.exists(f"{test_net}/shared-0.tar"):
-        os.makedirs(test_net)
-        test_net_cuts = wenetspeech.test_net_cuts()
-        export_to_webdataset(
-            test_net_cuts,
-            output_path=f"{test_net}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    if not os.path.exists(f"{test_meeting}/shared-0.tar"):
-        os.makedirs(test_meeting)
-        test_meeting_cuts = wenetspeech.test_meeting_cuts()
-        export_to_webdataset(
-            test_meeting_cuts,
-            output_path=f"{test_meeting}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    dev_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
-    ]
-    cuts_dev_webdataset = CutSet.from_webdataset(
-        dev_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    test_net_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
-    ]
-    cuts_test_net_webdataset = CutSet.from_webdataset(
-        test_net_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    test_meeting_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
-    ]
-    cuts_test_meeting_webdataset = CutSet.from_webdataset(
-        test_meeting_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset)
-    test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
-    test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meeting_webdataset)
-
-    test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
-    test_dl = [dev_dl, test_net_dl, test_meeting_dl]
-
-    for test_set, test_dl in zip(test_sets, test_dl):
+    # dev = "dev"
+    # test_net = "test_net"
+    # test_meeting = "test_meeting"
+
+    # if not os.path.exists(f"{dev}/shared-0.tar"):
+    #     os.makedirs(dev)
+    #     dev_cuts = wenetspeech.valid_cuts()
+    #     export_to_webdataset(
+    #         dev_cuts,
+    #         output_path=f"{dev}/shared-%d.tar",
+    #         shard_size=300,
+    #     )
+
+    # if not os.path.exists(f"{test_net}/shared-0.tar"):
+    #     os.makedirs(test_net)
+    #     test_net_cuts = wenetspeech.test_net_cuts()
+    #     export_to_webdataset(
+    #         test_net_cuts,
+    #         output_path=f"{test_net}/shared-%d.tar",
+    #         shard_size=300,
+    #     )
+
+    # if not os.path.exists(f"{test_meeting}/shared-0.tar"):
+    #     os.makedirs(test_meeting)
+    #     test_meeting_cuts = wenetspeech.test_meeting_cuts()
+    #     export_to_webdataset(
+    #         test_meeting_cuts,
+    #         output_path=f"{test_meeting}/shared-%d.tar",
+    #         shard_size=300,
+    #     )
+
+    # dev_shards = [
+    #     str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+    # ]
+    # cuts_dev_webdataset = CutSet.from_webdataset(
+    #     dev_shards,
+    #     split_by_worker=True,
+    #     split_by_node=True,
+    #     shuffle_shards=True,
+    # )
+
+    # test_net_shards = [
+    #     str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
+    # ]
+    # cuts_test_net_webdataset = CutSet.from_webdataset(
+    #     test_net_shards,
+    #     split_by_worker=True,
+    #     split_by_node=True,
+    #     shuffle_shards=True,
+    # )
+
+    # test_meeting_shards = [
+    #     str(path)
+    #     for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
+    # ]
+    # cuts_test_meeting_webdataset = CutSet.from_webdataset(
+    #     test_meeting_shards,
+    #     split_by_worker=True,
+    #     split_by_node=True,
+    #     shuffle_shards=True,
+    # )
+
+    # dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset)
+    # test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
+    # test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meeting_webdataset)
+
+    dev_cuts = wenetspeech.valid_cuts()
+    dev_dl = wenetspeech.valid_dataloaders(dev_cuts)
+
+    test_net_cuts = wenetspeech.test_net_cuts()
+    test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
+
+    test_meeting_cuts = wenetspeech.test_meeting_cuts()
+    test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
+
+    test_car_cuts = wenetspeech.test_car_cuts()
+    test_car_dl = wenetspeech.test_dataloaders(test_car_cuts)
+
+    # test_sets = ["CAR", "TEST_NET", "DEV", "TEST_MEETING"]
+    # test_dls = [test_car_dl, test_net_dl, dev_dl, test_meeting_dl]
+
+    test_sets = ["CAR", "TEST_NET", "TEST_MEETING"]
+    test_dls = [test_car_dl, test_net_dl, test_meeting_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dls):
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,
             model=model,
             lexicon=lexicon,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
         )
 
         save_results(
             params=params,
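
The new `--context-file` option expects a plain-text file with one biasing phrase per line; `build_context_graph_char` strips whitespace inside each phrase and skips phrases containing characters absent from the token table. A hedged sketch of preparing such a file and building the char-level graph the same way `main()` does; the `tokens.txt` path is a placeholder for the recipe's char lang dir:

    import k2

    from icefall import ContextGraph

    # One biasing phrase per line.
    with open("context_words.txt", "w", encoding="utf-8") as f:
        f.write("你好中国\n北京欢迎您\n")

    # Placeholder path; use the token table the model was trained with.
    token_table = k2.SymbolTable.from_file("data/lang_char/tokens.txt")

    contexts = [line.strip() for line in open("context_words.txt", encoding="utf-8")]
    context_graph = ContextGraph(2.0)  # --context-score defaults to 2
    context_graph.build_context_graph_char(contexts, token_table)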

File 5 of 5

@@ -15,9 +15,12 @@
 # limitations under the License.
 
+import logging
+import re
 from dataclasses import dataclass
 from typing import List
 
 import argparse
+import k2
 import kaldifst
 import sentencepiece as spm
@@ -34,9 +37,43 @@ class ContextGraph:
     def __init__(self, context_score: float = 1):
         self.context_score = context_score
 
-    def build_context_graph(self, contexts: List[str], sp: spm.SentencePieceProcessor):
+    def build_context_graph_char(self, contexts: List[str], token_table: k2.SymbolTable):
+        """Convert a list of texts to a list-of-list of token IDs.
+
+        Args:
+          contexts:
+            It is a list of strings.
+            An example containing two strings is given below:
+
+                ['你好中国', '北京欢迎您']
+          token_table:
+            The SymbolTable containing tokens and corresponding ids.
+        Returns:
+          Return a list-of-list of token IDs.
+        """
+        ids: List[List[int]] = []
+        whitespace = re.compile(r"([ \t])")
+        for text in contexts:
+            text = re.sub(whitespace, "", text)
+            sub_ids: List[int] = []
+            skip = False
+            for txt in text:
+                if txt not in token_table:
+                    skip = True
+                    break
+                sub_ids.append(token_table[txt])
+            if skip:
+                logging.warning(f"Skipping context {text}, as it has OOV char.")
+                continue
+            ids.append(sub_ids)
+        self.build_context_graph(ids)
+
+    def build_context_graph_bpe(self, contexts: List[str], sp: spm.SentencePieceProcessor):
         contexts_bpe = sp.encode(contexts)
+        self.build_context_graph(contexts_bpe)
 
+    def build_context_graph(self, token_ids: List[List[int]]):
         graph = kaldifst.StdVectorFst()
         start_state = (
             graph.add_state()
@@ -45,18 +82,18 @@ class ContextGraph:
         graph.start = 0  # set the start state to 0
         graph.set_final(start_state, weight=0)  # weight is in log space
 
-        for bpe_ids in contexts_bpe:
+        for tokens in token_ids:
             prev_state = start_state
             next_state = start_state
             backoff_score = 0
-            for i in range(len(bpe_ids)):
+            for i in range(len(tokens)):
                 score = self.context_score
-                next_state = graph.add_state() if i < len(bpe_ids) - 1 else start_state
+                next_state = graph.add_state() if i < len(tokens) - 1 else start_state
                 graph.add_arc(
                     state=prev_state,
                     arc=kaldifst.StdArc(
-                        ilabel=bpe_ids[i],
-                        olabel=bpe_ids[i],
+                        ilabel=tokens[i],
+                        olabel=tokens[i],
                         weight=score,
                         nextstate=next_state,
                     ),
@@ -105,7 +142,7 @@ if __name__ == "__main__":
     contexts = ["LOVE CHINA", "HELLO WORLD", "LOVE WORLD"]
 
     context_graph = ContextGraph()
-    context_graph.build_context_graph(contexts, sp)
+    context_graph.build_context_graph_bpe(contexts, sp)
 
     if not is_module_available("graphviz"):
         raise ValueError("Please 'pip install graphviz' first.")
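
Taken together, `build_context_graph` compiles the phrases into a small keyword FST: every token arc carries a `context_score` bonus, intermediate states chain forward, and the arc for a phrase's last token returns to the start state (which is final), so a complete match keeps its accumulated bonus; the `backoff_score` bookkeeping above presumably feeds failure arcs that cancel the bonus on partial matches (that code sits outside the displayed hunks). A tiny sketch driving the ID-level entry point directly, with made-up token IDs:

    from icefall import ContextGraph

    # Hypothetical token-ID sequences, e.g. produced by a SymbolTable or
    # SentencePiece lookup for two biasing phrases.
    phrases = [
        [12, 7, 9],  # three-token phrase
        [33, 4],     # two-token phrase
    ]

    graph = ContextGraph(context_score=2.0)
    graph.build_context_graph(phrases)  # every matched token scores +2.0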