Add context biasing for wenetspeech

pkufool 2023-03-22 19:45:45 +08:00
parent 1f1e28c1ad
commit ca8ed842f7
5 changed files with 198 additions and 96 deletions

View File

@@ -913,7 +913,7 @@ def main():
for line in open(params.context_file).readlines():
contexts.append(line.strip())
context_graph = ContextGraph(params.context_score)
context_graph.build_context_graph(contexts, sp)
context_graph.build_context_graph_bpe(contexts, sp)
else:
context_graph = None
else:
@@ -935,8 +935,11 @@ def main():
test_book_cuts = librispeech.test_book_cuts()
test_book_dl = librispeech.test_dataloaders(test_book_cuts)
test_sets = ["test-book", "test-clean", "test-other"]
test_dl = [test_book_dl, test_clean_dl, test_other_dl]
test_book2_cuts = librispeech.test_book2_cuts()
test_book2_dl = librispeech.test_dataloaders(test_book2_cuts)
test_sets = ["test-book", "test-book2", "test-clean", "test-other"]
test_dl = [test_book_dl, test_book2_dl, test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(

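For reference, the BPE branch that the LibriSpeech decode.py now takes can be exercised on its own. A minimal sketch, assuming a trained SentencePiece model and a context file (both paths below are illustrative, not from the commit):

import os
import sentencepiece as spm
from icefall import ContextGraph  # same import the decode scripts use

# Illustrative paths; the real model/file come from the recipe's data dir.
sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")

context_file = "contexts.txt"  # one biasing phrase per line
context_score = 2.0            # mirrors the --context-score default

if os.path.exists(context_file):
    contexts = [line.strip() for line in open(context_file)]
    context_graph = ContextGraph(context_score)
    # BPE variant: the phrases are tokenized with the SentencePiece model.
    context_graph.build_context_graph_bpe(contexts, sp)
else:
    context_graph = None

The resulting graph is then passed to decode_dataset / modified_beam_search exactly as in the hunk above.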
View File

@@ -452,6 +452,13 @@ class LibriSpeechAsrDataModule:
self.args.manifest_dir / "libri_books_feats.jsonl.gz"
)
@lru_cache()
def test_book2_cuts(self) -> CutSet:
logging.info("About to get test-books2 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libri_books2_feats.jsonl.gz"
)
@lru_cache()
def test_other_cuts(self) -> CutSet:
logging.info("About to get test-other cuts")

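The new accessor assumes that the manifest directory already contains libri_books2_feats.jsonl.gz. A quick, hedged way to sanity-check such a manifest before wiring it into test_sets/test_dl in decode.py (the path is illustrative):

from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/fbank/libri_books2_feats.jsonl.gz")
# Prints counts and duration statistics for the cut set.
cuts.describe()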
View File

@@ -46,8 +46,8 @@ from torch.utils.data import DataLoader
from icefall.utils import str2bool
set_caching_enabled(False)
torch.set_num_threads(1)
# set_caching_enabled(False)
# torch.set_num_threads(1)
class _SeedWorkers:
@@ -109,7 +109,7 @@ class WenetSpeechAsrDataModule:
group.add_argument(
"--num-buckets",
type=int,
default=300,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
@@ -373,7 +373,7 @@ class WenetSpeechAsrDataModule:
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
logging.info("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
@@ -383,19 +383,22 @@ class WenetSpeechAsrDataModule:
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
rank=0,
world_size=1,
buffer_size=10000,
# rank=0,
# world_size=1,
shuffle=False,
)
from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
# from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
test_iter_dataset = IterableDatasetWrapper(
dataset=test,
sampler=sampler,
)
# test_iter_dataset = IterableDatasetWrapper(
# dataset=test,
# sampler=sampler,
# )
test_dl = DataLoader(
test_iter_dataset,
# test_iter_dataset,
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
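Put together, test_dataloaders after this hunk hands the DynamicBucketingSampler straight to a map-style DataLoader instead of going through IterableDatasetWrapper with a hard-coded rank/world_size. A sketch of the resulting logic as a standalone function; the else branch of the input strategy and the args object are assumptions, since they sit outside the hunk:

import logging

from lhotse import CutSet, Fbank, FbankConfig
from lhotse.dataset import (
    DynamicBucketingSampler,
    K2SpeechRecognitionDataset,
    OnTheFlyFeatures,
    PrecomputedFeatures,
)
from torch.utils.data import DataLoader


def make_test_dataloader(cuts: CutSet, args) -> DataLoader:
    """Roughly what WenetSpeechAsrDataModule.test_dataloaders does after this change."""
    logging.info("About to create test dataset")
    test = K2SpeechRecognitionDataset(
        input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
        if args.on_the_fly_feats
        else PrecomputedFeatures(),  # assumed: the else branch is not shown in the hunk
        return_cuts=args.return_cuts,
    )
    sampler = DynamicBucketingSampler(
        cuts,
        max_duration=args.max_duration,
        buffer_size=10000,  # new: larger bucketing buffer for the lazy test cuts
        shuffle=False,
    )
    # No IterableDatasetWrapper any more: the sampler goes directly to the DataLoader.
    return DataLoader(
        test,
        batch_size=None,
        sampler=sampler,
        num_workers=args.num_workers,
    )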
@@ -411,14 +414,19 @@ class WenetSpeechAsrDataModule:
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV2.jsonl.gz")
@lru_cache()
def test_net_cuts(self) -> List[CutSet]:
logging.info("About to get TEST_NET cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz")
return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET2.jsonl.gz")
@lru_cache()
def test_meeting_cuts(self) -> List[CutSet]:
logging.info("About to get TEST_MEETING cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz")
return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING2.jsonl.gz")
@lru_cache()
def test_car_cuts(self) -> List[CutSet]:
logging.info("About to get TEST_CAR cuts")
return load_manifest_lazy(self.args.manifest_dir / "car_test_feats.jsonl.gz")

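The renamed manifests (cuts_DEV2, cuts_TEST_NET2, cuts_TEST_MEETING2) and the new car_test_feats.jsonl.gz are expected to exist already; the commit does not show how they are produced. A hedged sketch of building such a feature manifest with lhotse's standard API; every path, and the existence of recording/supervision manifests for the car test set, are assumptions:

from lhotse import CutSet, Fbank, FbankConfig, RecordingSet, SupervisionSet

# Assumed inputs: audio and transcript manifests prepared for the custom test set.
recordings = RecordingSet.from_file("data/manifests/car_test_recordings.jsonl.gz")
supervisions = SupervisionSet.from_file("data/manifests/car_test_supervisions.jsonl.gz")

cuts = CutSet.from_manifests(recordings=recordings, supervisions=supervisions)
cuts = cuts.compute_and_store_features(
    extractor=Fbank(FbankConfig(num_mel_bins=80)),  # matches the 80-dim fbank used above
    storage_path="data/fbank/car_test_feats",
    num_jobs=4,
)
cuts.to_file("data/fbank/car_test_feats.jsonl.gz")  # what test_car_cuts() will load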
View File

@@ -95,8 +95,10 @@ When training with the L subset, the streaming usage:
import argparse
import glob
import logging
import math
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@@ -114,6 +116,7 @@ from beam_search import (
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall import ContextGraph
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
@@ -277,6 +280,20 @@ def get_parser():
help="left context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--context-score",
type=float,
default=2,
help="The bonus score added for each token matched in the context graph "
"(used only when --context-file is given).",
)
parser.add_argument(
"--context-file",
type=str,
default="",
help="Path to a text file containing the context phrases for biasing, "
"one phrase per line. Leave empty to disable context biasing.",
)
add_model_arguments(parser)
return parser
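The context file that both new options refer to is just plain text with one biasing phrase per line; --context-score is the per-token bonus each matched token receives in the graph. A small illustrative round trip (file name and phrases are made up):

# Characters for WenetSpeech; in the LibriSpeech recipe the phrases
# would be English text that gets BPE-encoded instead.
phrases = ["北京欢迎您", "你好中国"]

with open("contexts.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(phrases) + "\n")

# decode.py reads the file back line by line, exactly as in the hunk further down:
contexts = [line.strip() for line in open("contexts.txt", encoding="utf-8")]
assert contexts == phrases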
@@ -288,6 +305,7 @@ def decode_one_batch(
lexicon: Lexicon,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@@ -325,14 +343,13 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
if params.simulate_streaming:
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
if params.simulate_streaming:
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,
@@ -371,6 +388,7 @@ def decode_one_batch(
encoder_out=encoder_out,
beam=params.beam_size,
encoder_out_lens=encoder_out_lens,
context_graph=context_graph
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
@@ -410,7 +428,7 @@ def decode_one_batch(
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
return {f"beam_size_{params.beam_size}_context_score_{params.context_score}": hyps}
def decode_dataset(
@@ -419,6 +437,7 @@ def decode_dataset(
model: nn.Module,
lexicon: Lexicon,
decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
@@ -463,6 +482,7 @@ def decode_dataset(
lexicon=lexicon,
decoding_graph=decoding_graph,
batch=batch,
context_graph=context_graph,
)
for name, hyps in hyps_dict.items():
@@ -551,6 +571,7 @@ def main():
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
params.suffix += f"-context-score-{params.context_score}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -664,13 +685,23 @@ def main():
else:
decoding_graph = None
if params.decoding_method == "modified_beam_search":
if os.path.exists(params.context_file):
contexts = []
for line in open(params.context_file).readlines():
contexts.append(line.strip())
context_graph = ContextGraph(params.context_score)
context_graph.build_context_graph_char(contexts, lexicon.token_table)
else:
context_graph = None
else:
context_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# Note: Please use "pip install webdataset==0.1.103"
# for installing the webdataset.
import glob
import os
from lhotse import CutSet
from lhotse.dataset.webdataset import export_to_webdataset
@@ -679,82 +710,98 @@ def main():
args.return_cuts = True
wenetspeech = WenetSpeechAsrDataModule(args)
dev = "dev"
test_net = "test_net"
test_meeting = "test_meeting"
#dev = "dev"
#test_net = "test_net"
#test_meeting = "test_meeting"
#if not os.path.exists(f"{dev}/shared-0.tar"):
# os.makedirs(dev)
# dev_cuts = wenetspeech.valid_cuts()
# export_to_webdataset(
# dev_cuts,
# output_path=f"{dev}/shared-%d.tar",
# shard_size=300,
# )
#if not os.path.exists(f"{test_net}/shared-0.tar"):
# os.makedirs(test_net)
# test_net_cuts = wenetspeech.test_net_cuts()
# export_to_webdataset(
# test_net_cuts,
# output_path=f"{test_net}/shared-%d.tar",
# shard_size=300,
# )
#if not os.path.exists(f"{test_meeting}/shared-0.tar"):
# os.makedirs(test_meeting)
# test_meeting_cuts = wenetspeech.test_meeting_cuts()
# export_to_webdataset(
# test_meeting_cuts,
# output_path=f"{test_meeting}/shared-%d.tar",
# shard_size=300,
# )
#dev_shards = [
# str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
#]
#cuts_dev_webdataset = CutSet.from_webdataset(
# dev_shards,
# split_by_worker=True,
# split_by_node=True,
# shuffle_shards=True,
#)
#test_net_shards = [
# str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
#]
#cuts_test_net_webdataset = CutSet.from_webdataset(
# test_net_shards,
# split_by_worker=True,
# split_by_node=True,
# shuffle_shards=True,
#)
#test_meeting_shards = [
# str(path)
# for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
#]
#cuts_test_meeting_webdataset = CutSet.from_webdataset(
# test_meeting_shards,
# split_by_worker=True,
# split_by_node=True,
# shuffle_shards=True,
#)
#dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset)
#test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
#test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meeting_webdataset)
if not os.path.exists(f"{dev}/shared-0.tar"):
os.makedirs(dev)
dev_cuts = wenetspeech.valid_cuts()
export_to_webdataset(
dev_cuts,
output_path=f"{dev}/shared-%d.tar",
shard_size=300,
)
dev_dl = wenetspeech.valid_dataloaders(dev_cuts)
if not os.path.exists(f"{test_net}/shared-0.tar"):
os.makedirs(test_net)
test_net_cuts = wenetspeech.test_net_cuts()
export_to_webdataset(
test_net_cuts,
output_path=f"{test_net}/shared-%d.tar",
shard_size=300,
)
test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
if not os.path.exists(f"{test_meeting}/shared-0.tar"):
os.makedirs(test_meeting)
test_meeting_cuts = wenetspeech.test_meeting_cuts()
export_to_webdataset(
test_meeting_cuts,
output_path=f"{test_meeting}/shared-%d.tar",
shard_size=300,
)
test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
dev_shards = [
str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
]
cuts_dev_webdataset = CutSet.from_webdataset(
dev_shards,
split_by_worker=True,
split_by_node=True,
shuffle_shards=True,
)
test_car_cuts = wenetspeech.test_car_cuts()
test_car_dl = wenetspeech.test_dataloaders(test_car_cuts)
test_net_shards = [
str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
]
cuts_test_net_webdataset = CutSet.from_webdataset(
test_net_shards,
split_by_worker=True,
split_by_node=True,
shuffle_shards=True,
)
# test_sets = ["CAR", "TEST_NET", "DEV", "TEST_MEETING"]
# test_dls = [test_car_dl, test_net_dl, dev_dl, test_meeting_dl]
test_meeting_shards = [
str(path)
for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
]
cuts_test_meeting_webdataset = CutSet.from_webdataset(
test_meeting_shards,
split_by_worker=True,
split_by_node=True,
shuffle_shards=True,
)
test_sets = ["CAR", "TEST_NET", "TEST_MEETING"]
test_dls = [test_car_dl, test_net_dl, test_meeting_dl]
dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset)
test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meeting_webdataset)
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
test_dl = [dev_dl, test_net_dl, test_meeting_dl]
for test_set, test_dl in zip(test_sets, test_dl):
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
lexicon=lexicon,
decoding_graph=decoding_graph,
context_graph=context_graph,
)
save_results(
params=params,

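build_context_graph_char (defined in the next file) looks every character of a phrase up in lexicon.token_table and drops phrases with OOV characters. A toy illustration of that lookup with a hand-built k2.SymbolTable; the real table comes from the lang_char lexicon and also contains <blk> and other symbols:

import k2

# Toy table for illustration only; the ids here are arbitrary.
token_table = k2.SymbolTable.from_str("你 1\n好 2\n中 3\n国 4")

phrase = "你好中国"
if all(ch in token_table for ch in phrase):
    ids = [token_table[ch] for ch in phrase]  # -> [1, 2, 3, 4]
else:
    # build_context_graph_char logs a warning and skips such phrases.
    ids = None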
View File

@@ -15,9 +15,12 @@
# limitations under the License.
import logging
import re
from dataclasses import dataclass
from typing import List
import argparse
import k2
import kaldifst
import sentencepiece as spm
@@ -34,9 +37,43 @@ class ContextGraph:
def __init__(self, context_score: float = 1):
self.context_score = context_score
def build_context_graph(self, contexts: List[str], sp: spm.SentencePieceProcessor):
def build_context_graph_char(self, contexts: List[str], token_table: k2.SymbolTable):
"""Convert a list of texts to a list-of-list of token IDs.
Args:
contexts:
It is a list of strings.
An example containing two strings is given below:
['你好中国', '北京欢迎您']
token_table:
The SymbolTable containing tokens and corresponding ids.
Returns:
Return a list-of-list of token IDs.
"""
ids: List[List[int]] = []
whitespace = re.compile(r"([ \t])")
for text in contexts:
text = re.sub(whitespace, "", text)
sub_ids: List[int] = []
skip = False
for txt in text:
if txt not in token_table:
skip = True
break
sub_ids.append(token_table[txt])
if skip:
logging.warning(f"Skipping context {text}, as it has OOV char.")
continue
ids.append(sub_ids)
self.build_context_graph(ids)
def build_context_graph_bpe(self, contexts: List[str], sp: spm.SentencePieceProcessor):
contexts_bpe = sp.encode(contexts)
self.build_context_graph(contexts_bpe)
def build_context_graph(self, token_ids: List[List[int]]):
graph = kaldifst.StdVectorFst()
start_state = (
graph.add_state()
@@ -45,18 +82,18 @@ class ContextGraph:
graph.start = 0 # set the start state to 0
graph.set_final(start_state, weight=0) # weight is in log space
for bpe_ids in contexts_bpe:
for tokens in token_ids:
prev_state = start_state
next_state = start_state
backoff_score = 0
for i in range(len(bpe_ids)):
for i in range(len(tokens)):
score = self.context_score
next_state = graph.add_state() if i < len(bpe_ids) - 1 else start_state
next_state = graph.add_state() if i < len(tokens) - 1 else start_state
graph.add_arc(
state=prev_state,
arc=kaldifst.StdArc(
ilabel=bpe_ids[i],
olabel=bpe_ids[i],
ilabel=tokens[i],
olabel=tokens[i],
weight=score,
nextstate=next_state,
),
@@ -105,7 +142,7 @@ if __name__ == "__main__":
contexts = ["LOVE CHINA", "HELLO WORLD", "LOVE WORLD"]
context_graph = ContextGraph()
context_graph.build_context_graph(contexts, sp)
context_graph.build_context_graph_bpe(contexts, sp)
if not is_module_available("graphviz"):
raise ValueError("Please 'pip install graphviz' first.")
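The hunks above pass context_graph into modified_beam_search, but the beam-search changes themselves are not part of this excerpt. As a rough, self-contained illustration of the biasing idea only (not icefall's implementation, which walks the kaldifst graph built above and also handles backoff when a partial match fails), the sketch below scores a token sequence against a trie of context phrases, adding context_score for every token that extends a phrase:

from typing import Dict, List


class TrieNode:
    def __init__(self) -> None:
        self.children: Dict[int, "TrieNode"] = {}


def build_trie(phrases: List[List[int]]) -> TrieNode:
    """Build a token-id trie over the context phrases."""
    root = TrieNode()
    for phrase in phrases:
        node = root
        for tok in phrase:
            node = node.children.setdefault(tok, TrieNode())
    return root


def context_bonus(hyp: List[int], root: TrieNode, context_score: float = 2.0) -> float:
    """Accumulate a per-token bonus while the hypothesis follows some
    context phrase; restart from the root when the match breaks.
    Unlike the FST above, this toy version never takes back the bonus
    of a partial match (no backoff arcs)."""
    bonus = 0.0
    node = root
    for tok in hyp:
        if tok in node.children:
            node = node.children[tok]
        elif tok in root.children:  # token starts a new phrase
            node = root.children[tok]
        else:
            node = root
            continue
        bonus += context_score
    return bonus


# In beam search, such a bonus would be added to a hypothesis' log-prob
# incrementally, one emitted token at a time.
trie = build_trie([[1, 2, 3, 4]])                # e.g. the ids of "你好中国"
print(context_bonus([7, 1, 2, 3, 4, 9], trie))   # 8.0 = 4 matched tokens * 2.0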