Mirror of https://github.com/k2-fsa/icefall.git
Add context biasing for wenetspeech
Commit ca8ed842f7 (parent 1f1e28c1ad)
@@ -913,7 +913,7 @@ def main():
             for line in open(params.context_file).readlines():
                 contexts.append(line.strip())
             context_graph = ContextGraph(params.context_score)
-            context_graph.build_context_graph(contexts, sp)
+            context_graph.build_context_graph_bpe(contexts, sp)
         else:
             context_graph = None
     else:
@@ -935,8 +935,11 @@ def main():
     test_book_cuts = librispeech.test_book_cuts()
     test_book_dl = librispeech.test_dataloaders(test_book_cuts)
 
-    test_sets = ["test-book", "test-clean", "test-other"]
-    test_dl = [test_book_dl, test_clean_dl, test_other_dl]
+    test_book2_cuts = librispeech.test_book2_cuts()
+    test_book2_dl = librispeech.test_dataloaders(test_book2_cuts)
+
+    test_sets = ["test-book", "test-book2", "test-clean", "test-other"]
+    test_dl = [test_book_dl, test_book2_dl, test_clean_dl, test_other_dl]
 
     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(
@@ -452,6 +452,13 @@ class LibriSpeechAsrDataModule:
             self.args.manifest_dir / "libri_books_feats.jsonl.gz"
         )
 
+    @lru_cache()
+    def test_book2_cuts(self) -> CutSet:
+        logging.info("About to get test-books2 cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libri_books2_feats.jsonl.gz"
+        )
+
     @lru_cache()
     def test_other_cuts(self) -> CutSet:
         logging.info("About to get test-other cuts")
@@ -46,8 +46,8 @@ from torch.utils.data import DataLoader
 
 from icefall.utils import str2bool
 
-set_caching_enabled(False)
-torch.set_num_threads(1)
+# set_caching_enabled(False)
+# torch.set_num_threads(1)
 
 
 class _SeedWorkers:
@@ -109,7 +109,7 @@ class WenetSpeechAsrDataModule:
         group.add_argument(
             "--num-buckets",
             type=int,
-            default=300,
+            default=30,
             help="The number of buckets for the DynamicBucketingSampler"
             "(you might want to increase it for larger datasets).",
         )
@@ -373,7 +373,7 @@ class WenetSpeechAsrDataModule:
         return valid_dl
 
     def test_dataloaders(self, cuts: CutSet) -> DataLoader:
-        logging.debug("About to create test dataset")
+        logging.info("About to create test dataset")
         test = K2SpeechRecognitionDataset(
             input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
             if self.args.on_the_fly_feats
@@ -383,19 +383,22 @@ class WenetSpeechAsrDataModule:
         sampler = DynamicBucketingSampler(
             cuts,
             max_duration=self.args.max_duration,
-            rank=0,
-            world_size=1,
+            buffer_size=10000,
+            # rank=0,
+            # world_size=1,
             shuffle=False,
         )
-        from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
+        # from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
 
-        test_iter_dataset = IterableDatasetWrapper(
-            dataset=test,
-            sampler=sampler,
-        )
+        # test_iter_dataset = IterableDatasetWrapper(
+        #     dataset=test,
+        #     sampler=sampler,
+        # )
         test_dl = DataLoader(
-            test_iter_dataset,
+            # test_iter_dataset,
+            test,
             batch_size=None,
+            sampler=sampler,
             num_workers=self.args.num_workers,
         )
         return test_dl
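The hunk above stops wrapping the dataset and sampler in an IterableDatasetWrapper and instead hands the DynamicBucketingSampler straight to DataLoader with batch_size=None. A generic PyTorch sketch of the pattern this relies on is shown below (the toy dataset and sampler names are invented for illustration): with automatic batching disabled, each item yielded by the sampler is passed whole to the dataset's __getitem__, which is how lhotse's K2SpeechRecognitionDataset consumes a CutSet batch.

# Toy stand-ins (names invented): the sampler yields whole batches of indices,
# the dataset's __getitem__ accepts such a batch, and DataLoader with
# batch_size=None forwards each sampler item to __getitem__ unchanged.
import torch
from torch.utils.data import DataLoader, Dataset, Sampler


class BatchOfIndicesSampler(Sampler):
    def __init__(self, num_items: int, batch_size: int):
        self.num_items = num_items
        self.batch_size = batch_size

    def __iter__(self):
        for start in range(0, self.num_items, self.batch_size):
            yield list(range(start, min(start + self.batch_size, self.num_items)))

    def __len__(self):
        return (self.num_items + self.batch_size - 1) // self.batch_size


class BatchDataset(Dataset):
    def __getitem__(self, batch_indices):
        # Receives the full list yielded by the sampler, much like
        # K2SpeechRecognitionDataset receives a whole CutSet.
        return torch.tensor(batch_indices, dtype=torch.float32)


dl = DataLoader(
    BatchDataset(),
    sampler=BatchOfIndicesSampler(num_items=10, batch_size=4),
    batch_size=None,  # disable automatic batching; the sampler already batches
)
for batch in dl:
    print(batch)  # tensors of length 4, 4, 2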
@@ -411,14 +414,19 @@ class WenetSpeechAsrDataModule:
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV2.jsonl.gz")
 
     @lru_cache()
     def test_net_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_NET cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET2.jsonl.gz")
 
     @lru_cache()
     def test_meeting_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_MEETING cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING2.jsonl.gz")
+
+    @lru_cache()
+    def test_car_cuts(self) -> List[CutSet]:
+        logging.info("About to get TEST_CAR cuts")
+        return load_manifest_lazy(self.args.manifest_dir / "car_test_feats.jsonl.gz")
@@ -95,8 +95,10 @@ When training with the L subset, the streaming usage:
 
 
 import argparse
+import glob
 import logging
 import math
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -114,6 +116,7 @@ from beam_search import (
 )
 from train import add_model_arguments, get_params, get_transducer_model
 
+from icefall import ContextGraph
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -277,6 +280,20 @@ def get_parser():
         help="left context can be seen during decoding (in frames after subsampling)",
     )
 
+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=2,
+        help="",
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="",
+    )
+
     add_model_arguments(parser)
 
     return parser
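The two new flags land with empty help strings. Judging from how main() uses them later in this diff, --context-file is expected to point at a plain-text file with one biasing phrase per line, and --context-score is the weight placed on every arc of the biasing graph. A minimal sketch of that usage follows; the file name and phrases are invented, and it assumes an icefall installation that already contains this commit's ContextGraph.

# Hypothetical preparation of the inputs behind --context-file / --context-score.
from icefall import ContextGraph

# --context-file: one biasing phrase per line (file name made up).
with open("context_phrases.txt", "w", encoding="utf-8") as f:
    f.write("你好中国\n北京欢迎您\n")

# Mirrors the loop in main(): read the file and strip each line.
contexts = []
for line in open("context_phrases.txt", encoding="utf-8").readlines():
    contexts.append(line.strip())

# --context-score: the per-token weight used when building the graph.
context_graph = ContextGraph(context_score=2.0)
# The WenetSpeech (char-based) recipe then calls
#     context_graph.build_context_graph_char(contexts, lexicon.token_table)
# while the BPE-based recipe calls build_context_graph_bpe(contexts, sp).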
@@ -288,6 +305,7 @@ def decode_one_batch(
     lexicon: Lexicon,
     batch: dict,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -325,14 +343,13 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    feature_lens += params.left_context
-    feature = torch.nn.functional.pad(
-        feature,
-        pad=(0, 0, 0, params.left_context),
-        value=LOG_EPS,
-    )
-
     if params.simulate_streaming:
+        feature_lens += params.left_context
+        feature = torch.nn.functional.pad(
+            feature,
+            pad=(0, 0, 0, params.left_context),
+            value=LOG_EPS,
+        )
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
@@ -371,6 +388,7 @@ def decode_one_batch(
             encoder_out=encoder_out,
             beam=params.beam_size,
             encoder_out_lens=encoder_out_lens,
+            context_graph=context_graph
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
@@ -410,7 +428,7 @@ def decode_one_batch(
             ): hyps
         }
     else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        return {f"beam_size_{params.beam_size}_context_score_{params.context_score}": hyps}
 
 
 def decode_dataset(
@@ -419,6 +437,7 @@ def decode_dataset(
     model: nn.Module,
     lexicon: Lexicon,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
 
@@ -463,6 +482,7 @@ def decode_dataset(
             lexicon=lexicon,
             decoding_graph=decoding_graph,
             batch=batch,
+            context_graph=context_graph,
         )
 
         for name, hyps in hyps_dict.items():
@@ -551,6 +571,7 @@ def main():
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam_size}"
+        params.suffix += f"-context-score-{params.context_score}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -664,13 +685,23 @@ def main():
     else:
         decoding_graph = None
 
+    if params.decoding_method == "modified_beam_search":
+        if os.path.exists(params.context_file):
+            contexts = []
+            for line in open(params.context_file).readlines():
+                contexts.append(line.strip())
+            context_graph = ContextGraph(params.context_score)
+            context_graph.build_context_graph_char(contexts, lexicon.token_table)
+        else:
+            context_graph = None
+    else:
+        context_graph = None
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
     # Note: Please use "pip install webdataset==0.1.103"
     # for installing the webdataset.
-    import glob
-    import os
 
     from lhotse import CutSet
     from lhotse.dataset.webdataset import export_to_webdataset
@@ -679,82 +710,98 @@ def main():
     args.return_cuts = True
     wenetspeech = WenetSpeechAsrDataModule(args)
 
-    dev = "dev"
-    test_net = "test_net"
-    test_meeting = "test_meeting"
+    #dev = "dev"
+    #test_net = "test_net"
+    #test_meeting = "test_meeting"
 
-    if not os.path.exists(f"{dev}/shared-0.tar"):
-        os.makedirs(dev)
-        dev_cuts = wenetspeech.valid_cuts()
-        export_to_webdataset(
-            dev_cuts,
-            output_path=f"{dev}/shared-%d.tar",
-            shard_size=300,
-        )
+    #if not os.path.exists(f"{dev}/shared-0.tar"):
+    #    os.makedirs(dev)
+    #    dev_cuts = wenetspeech.valid_cuts()
+    #    export_to_webdataset(
+    #        dev_cuts,
+    #        output_path=f"{dev}/shared-%d.tar",
+    #        shard_size=300,
+    #    )
 
-    if not os.path.exists(f"{test_net}/shared-0.tar"):
-        os.makedirs(test_net)
-        test_net_cuts = wenetspeech.test_net_cuts()
-        export_to_webdataset(
-            test_net_cuts,
-            output_path=f"{test_net}/shared-%d.tar",
-            shard_size=300,
-        )
+    #if not os.path.exists(f"{test_net}/shared-0.tar"):
+    #    os.makedirs(test_net)
+    #    test_net_cuts = wenetspeech.test_net_cuts()
+    #    export_to_webdataset(
+    #        test_net_cuts,
+    #        output_path=f"{test_net}/shared-%d.tar",
+    #        shard_size=300,
+    #    )
 
-    if not os.path.exists(f"{test_meeting}/shared-0.tar"):
-        os.makedirs(test_meeting)
-        test_meeting_cuts = wenetspeech.test_meeting_cuts()
-        export_to_webdataset(
-            test_meeting_cuts,
-            output_path=f"{test_meeting}/shared-%d.tar",
-            shard_size=300,
-        )
+    #if not os.path.exists(f"{test_meeting}/shared-0.tar"):
+    #    os.makedirs(test_meeting)
+    #    test_meeting_cuts = wenetspeech.test_meeting_cuts()
+    #    export_to_webdataset(
+    #        test_meeting_cuts,
+    #        output_path=f"{test_meeting}/shared-%d.tar",
+    #        shard_size=300,
+    #    )
 
-    dev_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
-    ]
-    cuts_dev_webdataset = CutSet.from_webdataset(
-        dev_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
+    #dev_shards = [
+    #    str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+    #]
+    #cuts_dev_webdataset = CutSet.from_webdataset(
+    #    dev_shards,
+    #    split_by_worker=True,
+    #    split_by_node=True,
+    #    shuffle_shards=True,
+    #)
 
-    test_net_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
-    ]
-    cuts_test_net_webdataset = CutSet.from_webdataset(
-        test_net_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
+    #test_net_shards = [
+    #    str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
+    #]
+    #cuts_test_net_webdataset = CutSet.from_webdataset(
+    #    test_net_shards,
+    #    split_by_worker=True,
+    #    split_by_node=True,
+    #    shuffle_shards=True,
+    #)
 
-    test_meeting_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
-    ]
-    cuts_test_meeting_webdataset = CutSet.from_webdataset(
-        test_meeting_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
+    #test_meeting_shards = [
+    #    str(path)
+    #    for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
+    #]
+    #cuts_test_meeting_webdataset = CutSet.from_webdataset(
+    #    test_meeting_shards,
+    #    split_by_worker=True,
+    #    split_by_node=True,
+    #    shuffle_shards=True,
+    #)
 
-    dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset)
-    test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
-    test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meeting_webdataset)
+    #dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset)
+    #test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
+    #test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meeting_webdataset)
 
-    test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
-    test_dl = [dev_dl, test_net_dl, test_meeting_dl]
+    dev_cuts = wenetspeech.valid_cuts()
+    dev_dl = wenetspeech.valid_dataloaders(dev_cuts)
 
-    for test_set, test_dl in zip(test_sets, test_dl):
+    test_net_cuts = wenetspeech.test_net_cuts()
+    test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
+
+    test_meeting_cuts = wenetspeech.test_meeting_cuts()
+    test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
+
+    test_car_cuts = wenetspeech.test_car_cuts()
+    test_car_dl = wenetspeech.test_dataloaders(test_car_cuts)
+
+    # test_sets = ["CAR", "TEST_NET", "DEV", "TEST_MEETING"]
+    # test_dls = [test_car_dl, test_net_dl, dev_dl, test_meeting_dl]
+
+    test_sets = ["CAR", "TEST_NET", "TEST_MEETING"]
+    test_dls = [test_car_dl, test_net_dl, test_meeting_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dls):
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,
             model=model,
             lexicon=lexicon,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
         )
         save_results(
             params=params,
@@ -15,9 +15,12 @@
 # limitations under the License.
 
 
+import logging
+import re
 from dataclasses import dataclass
 from typing import List
 import argparse
+import k2
 import kaldifst
 import sentencepiece as spm
 
@@ -34,9 +37,43 @@ class ContextGraph:
     def __init__(self, context_score: float = 1):
         self.context_score = context_score
 
-    def build_context_graph(self, contexts: List[str], sp: spm.SentencePieceProcessor):
+    def build_context_graph_char(self, contexts: List[str], token_table: k2.SymbolTable):
+        """Convert a list of texts to a list-of-list of token IDs.
+
+        Args:
+          contexts:
+            It is a list of strings.
+            An example containing two strings is given below:
+
+            ['你好中国', '北京欢迎您']
+          token_table:
+            The SymbolTable containing tokens and corresponding ids.
+
+        Returns:
+          Return a list-of-list of token IDs.
+        """
+        ids: List[List[int]] = []
+        whitespace = re.compile(r"([ \t])")
+        for text in contexts:
+            text = re.sub(whitespace, "", text)
+            sub_ids : List[int] = []
+            skip = False
+            for txt in text:
+                if txt not in token_table:
+                    skip = True
+                    break
+                sub_ids.append(token_table[txt])
+            if skip:
+                logging.warning(f"Skipping context {text}, as it has OOV char.")
+                continue
+            ids.append(sub_ids)
+        self.build_context_graph(ids)
+
+    def build_context_graph_bpe(self, contexts: List[str], sp: spm.SentencePieceProcessor):
         contexts_bpe = sp.encode(contexts)
+        self.build_context_graph(contexts_bpe)
 
+    def build_context_graph(self, token_ids: List[List[int]]):
         graph = kaldifst.StdVectorFst()
         start_state = (
             graph.add_state()
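For reference, a small hypothetical example of the char-level entry point added above; the token table contents are invented (real recipes pass lexicon.token_table), and it assumes k2 plus an icefall checkout containing this commit. Phrases with characters missing from the table are skipped with a warning, exactly as in build_context_graph_char.

# Hypothetical char-level usage with a toy symbol table.
import k2

from icefall import ContextGraph

token_table = k2.SymbolTable.from_str(
    "你 1\n好 2\n中 3\n国 4"
)

context_graph = ContextGraph(context_score=2.0)
# "北京欢迎您" is skipped with a warning: its characters are not in the table.
context_graph.build_context_graph_char(["你好中国", "北京欢迎您"], token_table)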
@@ -45,18 +82,18 @@ class ContextGraph:
         graph.start = 0  # set the start state to 0
         graph.set_final(start_state, weight=0)  # weight is in log space
 
-        for bpe_ids in contexts_bpe:
+        for tokens in token_ids:
             prev_state = start_state
             next_state = start_state
             backoff_score = 0
-            for i in range(len(bpe_ids)):
+            for i in range(len(tokens)):
                 score = self.context_score
-                next_state = graph.add_state() if i < len(bpe_ids) - 1 else start_state
+                next_state = graph.add_state() if i < len(tokens) - 1 else start_state
                 graph.add_arc(
                     state=prev_state,
                     arc=kaldifst.StdArc(
-                        ilabel=bpe_ids[i],
-                        olabel=bpe_ids[i],
+                        ilabel=tokens[i],
+                        olabel=tokens[i],
                         weight=score,
                         nextstate=next_state,
                     ),
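The loop above chains the tokens of each phrase away from the start state and back to it, giving every arc the same context_score; state 0 is both the start state and a final state, so walking a whole phrase accumulates len(phrase) * context_score of weight before returning to a final state. The hunk ends before the backoff handling hinted at by backoff_score, so the arcs that cancel a partially matched phrase are not shown here. Below is a standalone sketch of the chain for one made-up phrase, using only the kaldifst calls that already appear in the diff (plus the prev_state update the hunk truncates).

# One-phrase sketch of the biasing graph; token IDs and score are made up.
import kaldifst

context_score = 2.0
tokens = [5, 9, 7]  # hypothetical token IDs of a single biasing phrase

graph = kaldifst.StdVectorFst()
start_state = graph.add_state()
graph.start = 0  # state 0 is the start state
graph.set_final(start_state, weight=0)  # and also a final state (weight in log space)

prev_state = start_state
for i in range(len(tokens)):
    # The last token goes back to the start state; earlier ones open a new state.
    next_state = graph.add_state() if i < len(tokens) - 1 else start_state
    graph.add_arc(
        state=prev_state,
        arc=kaldifst.StdArc(
            ilabel=tokens[i],
            olabel=tokens[i],
            weight=context_score,  # every matched token carries the same weight
            nextstate=next_state,
        ),
    )
    prev_state = next_state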
@@ -105,7 +142,7 @@ if __name__ == "__main__":
 
     contexts = ["LOVE CHINA", "HELLO WORLD", "LOVE WORLD"]
     context_graph = ContextGraph()
-    context_graph.build_context_graph(contexts, sp)
+    context_graph.build_context_graph_bpe(contexts, sp)
 
     if not is_module_available("graphviz"):
         raise ValueError("Please 'pip install graphviz' first.")