Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 09:32:20 +00:00)

Add Context biasing (#1038)

* Add context biasing for librispeech
* Add context biasing for wenetspeech
* fix bugs
* Implement Aho-Corasick context graph
* fix some bugs
* Fixes to forward_one_step; add draw to context graph
* add output arc; fix black
* Fix wenetspeech tokenizer
* Minor fixes to the decode.py

This commit is contained in:
parent ca60ced213 · commit ba257efbcd
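At a glance, the new option works like this: decode.py reads --context-file (one word/phrase per line), tokenizes each entry, builds a ContextGraph with the per-token bonus --context-score, and passes it to modified_beam_search. Below is a minimal, self-contained sketch of the ContextGraph API added by this commit; it mirrors the self-test at the bottom of icefall/context_graph.py, and the two phrases plus the character-code "tokenizer" are illustrative only:

from icefall import ContextGraph

# Token ids of the biasing phrases; character codes stand in for real
# BPE / char ids here, exactly as in context_graph.py's __main__.
contexts = [[ord(c) for c in w] for w in ["HE", "SHE"]]

context_graph = ContextGraph(context_score=2.0)  # bonus per matched token
context_graph.build(contexts)

# Score a decoded sequence token by token.
state = context_graph.root
total = 0.0
for c in "SHE":
    score, state = context_graph.forward_one_step(state, ord(c))
    total += score
# Remove bonuses of partial matches that never completed.
score, state = context_graph.finalize(state)
total += score
print(total)  # 10.0: "SHE" (3 tokens) and "HE" (2 tokens) match, 5 * 2.0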
@@ -58,6 +58,7 @@ Usage:

import argparse
import logging
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

@@ -76,6 +77,8 @@ from beam_search import (
)
from train import add_model_arguments, get_params, get_transducer_model

from icefall import ContextGraph, LmScorer, NgramLm
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
@@ -211,6 +214,26 @@ def get_parser():
        Used only when --decoding_method is greedy_search""",
    )

    parser.add_argument(
        "--context-score",
        type=float,
        default=2,
        help="""
        The bonus score of each token for the context biasing words/phrases.
        Used only when --decoding_method is modified_beam_search.
        """,
    )

    parser.add_argument(
        "--context-file",
        type=str,
        default="",
        help="""
        The path of the context biasing list; one word/phrase per line.
        Used only when --decoding_method is modified_beam_search.
        """,
    )

    add_model_arguments(parser)

    return parser
@@ -222,6 +245,7 @@ def decode_one_batch(
    token_table: k2.SymbolTable,
    batch: dict,
    decoding_graph: Optional[k2.Fsa] = None,
    context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

@@ -285,6 +309,7 @@
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
            context_graph=context_graph,
        )
    else:
        hyp_tokens = []

@@ -324,7 +349,12 @@
            ): hyps
        }
    else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        key = f"beam_size_{params.beam_size}"
+        if params.has_contexts:
+            key += f"-context-score-{params.context_score}"
+        else:
+            key += "-no-context-words"
+        return {key: hyps}


def decode_dataset(
@@ -333,6 +363,7 @@ def decode_dataset(
    model: nn.Module,
    token_table: k2.SymbolTable,
    decoding_graph: Optional[k2.Fsa] = None,
    context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.

@@ -377,6 +408,7 @@
            model=model,
            token_table=token_table,
            decoding_graph=decoding_graph,
            context_graph=context_graph,
            batch=batch,
        )
@@ -407,16 +439,17 @@ def save_results(
    for key, results in results_dict.items():
        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
        results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
+        # we compute CER for aishell dataset.
+        results_char = []
+        for res in results:
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+
+        store_transcripts(filename=recog_path, texts=results_char)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        # we compute CER for aishell dataset.
-        results_char = []
-        for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -457,6 +490,12 @@ def main():
        "fast_beam_search",
        "modified_beam_search",
    )

    if os.path.exists(params.context_file):
        params.has_contexts = True
    else:
        params.has_contexts = False

    params.res_dir = params.exp_dir / params.decoding_method

    if params.iter > 0:

@@ -470,6 +509,10 @@
        params.suffix += f"-max-states-{params.max_states}"
    elif "beam_search" in params.decoding_method:
        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
        if params.has_contexts:
            params.suffix += f"-context-score-{params.context_score}"
        else:
            params.suffix += "-no-context-words"
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -490,6 +533,11 @@
    params.blank_id = 0
    params.vocab_size = max(lexicon.tokens) + 1

    graph_compiler = CharCtcTrainingGraphCompiler(
        lexicon=lexicon,
        device=device,
    )

    logging.info(params)

    logging.info("About to create model")

@@ -586,6 +634,19 @@ def main():
    else:
        decoding_graph = None

    if params.decoding_method == "modified_beam_search":
        if os.path.exists(params.context_file):
            contexts_text = []
            for line in open(params.context_file).readlines():
                contexts_text.append(line.strip())
            contexts = graph_compiler.texts_to_ids(contexts_text)
            context_graph = ContextGraph(params.context_score)
            context_graph.build(contexts)
        else:
            context_graph = None
    else:
        context_graph = None

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

@@ -608,6 +669,7 @@
        model=model,
        token_table=lexicon.token_table,
        decoding_graph=decoding_graph,
        context_graph=context_graph,
    )

    save_results(
@@ -24,7 +24,7 @@ import sentencepiece as spm
import torch
from model import Transducer

-from icefall import NgramLm, NgramLmStateCost
+from icefall import ContextGraph, ContextState, NgramLm, NgramLmStateCost
from icefall.decode import Nbest, one_best_decoding
from icefall.lm_wrapper import LmScorer
from icefall.rnn_lm.model import RnnLmModel

@@ -765,6 +765,9 @@ class Hypothesis:
    # N-gram LM state
    state_cost: Optional[NgramLmStateCost] = None

    # Context graph state
    context_state: Optional[ContextState] = None

    @property
    def key(self) -> str:
        """Return a string representation of self.ys"""
@@ -917,6 +920,7 @@
    model: Transducer,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    context_graph: Optional[ContextGraph] = None,
    beam: int = 4,
    temperature: float = 1.0,
    return_timestamps: bool = False,

@@ -968,6 +972,7 @@
            Hypothesis(
                ys=[blank_id] * context_size,
                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                context_state=None if context_graph is None else context_graph.root,
                timestamp=[],
            )
        )

@@ -990,6 +995,7 @@
        hyps_shape = get_hyps_shape(B).to(device)

        A = [list(b) for b in B]

        B = [HypothesisList() for _ in range(batch_size)]

        ys_log_probs = torch.cat(
@@ -1047,21 +1053,51 @@
            for k in range(len(topk_hyp_indexes)):
                hyp_idx = topk_hyp_indexes[k]
                hyp = A[i][hyp_idx]

                new_ys = hyp.ys[:]
                new_token = topk_token_indexes[k]
                new_timestamp = hyp.timestamp[:]
                context_score = 0
                new_context_state = None if context_graph is None else hyp.context_state
                if new_token not in (blank_id, unk_id):
                    new_ys.append(new_token)
                    new_timestamp.append(t)
                    if context_graph is not None:
                        (
                            context_score,
                            new_context_state,
                        ) = context_graph.forward_one_step(hyp.context_state, new_token)

-                new_log_prob = topk_log_probs[k]
+                new_log_prob = topk_log_probs[k] + context_score
                new_hyp = Hypothesis(
-                    ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
+                    ys=new_ys,
+                    log_prob=new_log_prob,
+                    timestamp=new_timestamp,
+                    context_state=new_context_state,
                )
                B[i].add(new_hyp)

        B = B + finalized_B

        # Finalize context_state: if a matched context does not reach an end
        # state, we need to add the score of the corresponding backoff arc.
        if context_graph is not None:
            finalized_B = [HypothesisList() for _ in range(len(B))]
            for i, hyps in enumerate(B):
                for hyp in list(hyps):
                    context_score, new_context_state = context_graph.finalize(
                        hyp.context_state
                    )
                    finalized_B[i].add(
                        Hypothesis(
                            ys=hyp.ys,
                            log_prob=hyp.log_prob + context_score,
                            timestamp=hyp.timestamp,
                            context_state=new_context_state,
                        )
                    )
            B = finalized_B

    best_hyps = [b.get_most_probable(length_norm=True) for b in B]

    sorted_ans = [h.ys[context_size:] for h in best_hyps]
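To see why the finalize block above is needed, trace the bonuses for a hypothesis that walks into a context word without finishing it. A worked example using the ContextGraph from this commit (context_score=1, matching the self-test in icefall/context_graph.py; the phrase list is illustrative):

from icefall import ContextGraph

# "HELL" is a prefix of "HELLO", and "HE" is a complete match inside it.
contexts = [[ord(c) for c in w] for w in ["HE", "HELLO"]]
graph = ContextGraph(context_score=1.0)
graph.build(contexts)

state, total = graph.root, 0.0
for c, expected in zip("HELL", (1, 1, 3, 1)):
    score, state = graph.forward_one_step(state, ord(c))
    # +1 per matched token; the extra +2 at the first 'L' locks in the
    # bonus of the completed word "HE" when we move past its end node.
    assert score == expected, (c, score)
    total += score
score, state = graph.finalize(state)
assert score == -4.0  # "HELL" never completed "HELLO": back off 4 tokens
assert total + score == 2.0  # only "HE" (2 tokens * 1.0) survives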
@@ -125,6 +125,7 @@ For example:

import argparse
import logging
import math
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

@@ -146,6 +147,7 @@ from beam_search import (
)
from train import add_model_arguments, get_params, get_transducer_model

from icefall import ContextGraph
from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,

@@ -353,6 +355,27 @@
        Used only when the decoding method is fast_beam_search_nbest,
        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--context-score",
        type=float,
        default=2,
        help="""
        The bonus score of each token for the context biasing words/phrases.
        Used only when --decoding_method is modified_beam_search.
        """,
    )

    parser.add_argument(
        "--context-file",
        type=str,
        default="",
        help="""
        The path of the context biasing list; one word/phrase per line.
        Used only when --decoding_method is modified_beam_search.
        """,
    )

    add_model_arguments(parser)

    return parser
@@ -365,6 +388,7 @@ def decode_one_batch(
    batch: dict,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
    context_graph: Optional[ContextGraph] = None,
) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

@@ -494,6 +518,7 @@
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
            context_graph=context_graph,
            return_timestamps=True,
        )
    else:

@@ -548,7 +573,12 @@

        return {key: (hyps, timestamps)}
    else:
-        return {f"beam_size_{params.beam_size}": (hyps, timestamps)}
+        key = f"beam_size_{params.beam_size}"
+        if params.has_contexts:
+            key += f"-context-score-{params.context_score}"
+        else:
+            key += "-no-context-words"
+        return {key: (hyps, timestamps)}


def decode_dataset(

@@ -558,6 +588,7 @@ def decode_dataset(
    sp: spm.SentencePieceProcessor,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
    context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]:
    """Decode dataset.

@@ -622,6 +653,7 @@
            decoding_graph=decoding_graph,
            word_table=word_table,
            batch=batch,
            context_graph=context_graph,
        )

        for name, (hyps, timestamps_hyp) in hyps_dict.items():
@@ -728,6 +760,12 @@ def main():
        "fast_beam_search_nbest_oracle",
        "modified_beam_search",
    )

    if os.path.exists(params.context_file):
        params.has_contexts = True
    else:
        params.has_contexts = False

    params.res_dir = params.exp_dir / params.decoding_method

    if params.iter > 0:

@@ -750,6 +788,10 @@
        params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
    elif "beam_search" in params.decoding_method:
        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
        if params.has_contexts:
            params.suffix += f"-context-score-{params.context_score}"
        else:
            params.suffix += "-no-context-words"
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

@@ -881,6 +923,18 @@ def main():
        decoding_graph = None
        word_table = None

    if params.decoding_method == "modified_beam_search":
        if os.path.exists(params.context_file):
            contexts = []
            for line in open(params.context_file).readlines():
                contexts.append(line.strip())
            context_graph = ContextGraph(params.context_score)
            context_graph.build(sp.encode(contexts))
        else:
            context_graph = None
    else:
        context_graph = None

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

@@ -905,6 +959,7 @@
        sp=sp,
        word_table=word_table,
        decoding_graph=decoding_graph,
        context_graph=context_graph,
    )

    save_results(
@@ -106,7 +106,7 @@ class WenetSpeechAsrDataModule:
        group.add_argument(
            "--num-buckets",
            type=int,
-            default=300,
+            default=30,
            help="The number of buckets for the DynamicBucketingSampler"
            "(you might want to increase it for larger datasets).",
        )

@@ -364,7 +364,7 @@ class WenetSpeechAsrDataModule:
        return valid_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
-        logging.debug("About to create test dataset")
+        logging.info("About to create test dataset")
        test = K2SpeechRecognitionDataset(
            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
            if self.args.on_the_fly_feats
@@ -92,7 +92,7 @@ When training with the L subset, the streaming usage:
    --causal-convolution 1 \
    --decode-chunk-size 16 \
    --left-context 64


(4) modified beam search with RNNLM shallow fusion
./pruned_transducer_stateless5/decode.py \
    --epoch 35 \

@@ -112,8 +112,10 @@ When training with the L subset, the streaming usage:


import argparse
import glob
import logging
import math
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

@@ -133,7 +135,8 @@ from beam_search import (
)
from train import add_model_arguments, get_params, get_transducer_model

-from icefall import LmScorer, NgramLm
+from icefall import ContextGraph, LmScorer, NgramLm
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
@@ -307,6 +310,26 @@ def get_parser():
        help="left context can be seen during decoding (in frames after subsampling)",
    )

    parser.add_argument(
        "--context-score",
        type=float,
        default=2,
        help="""
        The bonus score of each token for the context biasing words/phrases.
        Used only when --decoding_method is modified_beam_search.
        """,
    )

    parser.add_argument(
        "--context-file",
        type=str,
        default="",
        help="""
        The path of the context biasing list; one word/phrase per line.
        Used only when --decoding_method is modified_beam_search.
        """,
    )

    parser.add_argument(
        "--use-shallow-fusion",
        type=str2bool,
@@ -362,6 +385,7 @@ def decode_one_batch(
    lexicon: Lexicon,
    batch: dict,
    decoding_graph: Optional[k2.Fsa] = None,
    context_graph: Optional[ContextGraph] = None,
    ngram_lm: Optional[NgramLm] = None,
    ngram_lm_scale: float = 1.0,
    LM: Optional[LmScorer] = None,

@@ -402,14 +426,13 @@
    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

-    feature_lens += params.left_context
-    feature = torch.nn.functional.pad(
-        feature,
-        pad=(0, 0, 0, params.left_context),
-        value=LOG_EPS,
-    )

    if params.simulate_streaming:
+        feature_lens += params.left_context
+        feature = torch.nn.functional.pad(
+            feature,
+            pad=(0, 0, 0, params.left_context),
+            value=LOG_EPS,
+        )
        encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
            x=feature,
            x_lens=feature_lens,

@@ -448,6 +471,7 @@
            encoder_out=encoder_out,
            beam=params.beam_size,
            encoder_out_lens=encoder_out_lens,
            context_graph=context_graph,
        )
        for i in range(encoder_out.size(0)):
            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])

@@ -509,7 +533,12 @@
            ): hyps
        }
    else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        key = f"beam_size_{params.beam_size}"
+        if params.has_contexts:
+            key += f"-context-score-{params.context_score}"
+        else:
+            key += "-no-context-words"
+        return {key: hyps}


def decode_dataset(
@@ -518,6 +547,7 @@ def decode_dataset(
    model: nn.Module,
    lexicon: Lexicon,
    decoding_graph: Optional[k2.Fsa] = None,
    context_graph: Optional[ContextGraph] = None,
    ngram_lm: Optional[NgramLm] = None,
    ngram_lm_scale: float = 1.0,
    LM: Optional[LmScorer] = None,

@@ -567,6 +597,7 @@
            lexicon=lexicon,
            decoding_graph=decoding_graph,
            batch=batch,
            context_graph=context_graph,
            ngram_lm=ngram_lm,
            ngram_lm_scale=ngram_lm_scale,
            LM=LM,

@@ -646,6 +677,12 @@ def main():
        "modified_beam_search_lm_shallow_fusion",
        "modified_beam_search_LODR",
    )

    if os.path.exists(params.context_file):
        params.has_contexts = True
    else:
        params.has_contexts = False

    params.res_dir = params.exp_dir / params.decoding_method

    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

@@ -655,6 +692,10 @@
        params.suffix += f"-max-states-{params.max_states}"
    elif "beam_search" in params.decoding_method:
        params.suffix += f"-beam-{params.beam_size}"
        if params.has_contexts:
            params.suffix += f"-context-score-{params.context_score}"
        else:
            params.suffix += "-no-context-words"
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -684,11 +725,15 @@

    logging.info(f"Device: {device}")

-    # import pdb; pdb.set_trace()
    lexicon = Lexicon(params.lang_dir)
    params.blank_id = lexicon.token_table["<blk>"]
    params.vocab_size = max(lexicon.tokens) + 1

+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )

    if params.simulate_streaming:
        assert (
            params.causal_convolution

@@ -816,6 +861,19 @@ def main():
    else:
        decoding_graph = None

    if params.decoding_method == "modified_beam_search":
        if os.path.exists(params.context_file):
            contexts_text = []
            for line in open(params.context_file).readlines():
                contexts_text.append(line.strip())
            contexts = graph_compiler.texts_to_ids(contexts_text)
            context_graph = ContextGraph(params.context_score)
            context_graph.build(contexts)
        else:
            context_graph = None
    else:
        context_graph = None

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

@@ -833,15 +891,16 @@
    test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)

    test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
-    test_dl = [dev_dl, test_net_dl, test_meeting_dl]
+    test_dls = [dev_dl, test_net_dl, test_meeting_dl]

-    for test_set, test_dl in zip(test_sets, test_dl):
+    for test_set, test_dl in zip(test_sets, test_dls):
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
            lexicon=lexicon,
            decoding_graph=decoding_graph,
            context_graph=context_graph,
            ngram_lm=ngram_lm,
            ngram_lm_scale=ngram_lm_scale,
            LM=LM,
@@ -23,6 +23,8 @@ from .checkpoint import (
    save_checkpoint_with_global_batch_idx,
)

from .context_graph import ContextGraph, ContextState

from .decode import (
    get_lattice,
    nbest_decoding,
icefall/context_graph.py (new file, 412 lines)
@@ -0,0 +1,412 @@
# Copyright 2023 Xiaomi Corp. (authors: Wei Kang)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
from collections import deque
from typing import Dict, List, Optional, Tuple

class ContextState:
    """The state in ContextGraph"""

    def __init__(
        self,
        id: int,
        token: int,
        token_score: float,
        node_score: float,
        local_node_score: float,
        is_end: bool,
    ):
        """Create a ContextState.

        Args:
          id:
            The node id, only for visualization now. A node id is in
            [0, graph.num_nodes); the id of the root node is always 0.
          token:
            The token id.
          token_score:
            The bonus for each token during decoding, which will hopefully
            boost the token up to survive beam search.
          node_score:
            The accumulated bonus from the root of the graph to the current
            node; it is used to calculate the score of the fail arc.
          local_node_score:
            The accumulated bonus from the last ``end_node`` (a node whose
            ``is_end`` is True) to the current node; it is also used to
            calculate the score of the fail arc.
            Note: the ``local_node_score`` of an ``end_node`` is 0.
          is_end:
            True if the current token is the end of a context.
        """
        self.id = id
        self.token = token
        self.token_score = token_score
        self.node_score = node_score
        self.local_node_score = local_node_score
        self.is_end = is_end
        self.next = {}  # token -> ContextState (the trie arcs)
        self.fail = None  # the Aho-Corasick fail (backoff) arc
        self.output = None  # longest proper suffix that is itself a full context


class ContextGraph:
    """The ContextGraph is modified from Aho-Corasick; it is essentially a
    trie with a fail arc for each node.
    See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for more
    details of the Aho-Corasick algorithm.

    A ContextGraph contains some words/phrases whose scores we expect to boost
    during decoding. If a substring of a decoded sequence matches a word/phrase
    in the ContextGraph, we give the decoded sequence a bonus to help it
    survive beam search.
    """

    def __init__(self, context_score: float):
        """Initialize a ContextGraph with the given ``context_score``.

        A root node will be created (**NOTE:** the token of the root is
        hardcoded to -1).

        Args:
          context_score:
            The bonus score for each token (note: NOT for each word/phrase;
            a longer word/phrase therefore accumulates a larger total bonus,
            provided it is fully matched).
        """
        self.context_score = context_score
        self.num_nodes = 0
        self.root = ContextState(
            id=self.num_nodes,
            token=-1,
            token_score=0,
            node_score=0,
            local_node_score=0,
            is_end=False,
        )
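        # Point the root's fail arc back at itself, so that chasing fail
        # arcs always terminates (the root is detected via token == -1).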
        self.root.fail = self.root

    def _fill_fail_output(self):
        """Fill in the fail and output arcs for each trie node. This runs in
        linear time via a breadth-first search starting from the root.
        See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for
        the details of the algorithm.
        """
        queue = deque()
        for token, node in self.root.next.items():
            node.fail = self.root
            queue.append(node)
        while queue:
            current_node = queue.popleft()
            for token, node in current_node.next.items():
                fail = current_node.fail
                if token in fail.next:
                    fail = fail.next[token]
                else:
                    fail = fail.fail
                    while token not in fail.next:
                        fail = fail.fail
                        if fail.token == -1:  # root
                            break
                    if token in fail.next:
                        fail = fail.next[token]
                node.fail = fail
                # fill the output arc
                output = node.fail
                while not output.is_end:
                    output = output.fail
                    if output.token == -1:  # root
                        output = None
                        break
                node.output = output
                queue.append(node)

    def build(self, token_ids: List[List[int]]):
        """Build the ContextGraph from a list of token lists.
        It first builds a trie from the given token lists, then fills in the
        fail and output arcs for each trie node.

        See https://en.wikipedia.org/wiki/Trie for how to build a trie.

        Args:
          token_ids:
            The token lists to build the ContextGraph from. Each token list
            contains the token ids for one word/phrase. A token id could be
            the id of a char (modeling with single Chinese chars) or the id
            of a BPE piece (modeling with BPEs).
        """
        for tokens in token_ids:
            node = self.root
            for i, token in enumerate(tokens):
                if token not in node.next:
                    self.num_nodes += 1
                    is_end = i == len(tokens) - 1
                    node.next[token] = ContextState(
                        id=self.num_nodes,
                        token=token,
                        token_score=self.context_score,
                        node_score=node.node_score + self.context_score,
                        local_node_score=0
                        if is_end
                        else (node.local_node_score + self.context_score),
                        is_end=is_end,
                    )
                node = node.next[token]
        self._fill_fail_output()

    def forward_one_step(
        self, state: ContextState, token: int
    ) -> Tuple[float, ContextState]:
        """Search the graph with the given state and token.

        Args:
          state:
            The current graph state (a trie node) to start from.
          token:
            The given token.

        Returns:
          Return a tuple of (score, next state).
        """
        node = None
        score = 0
        # token matched
        if token in state.next:
            node = state.next[token]
            score = node.token_score
            if state.is_end:
                score += state.node_score
        else:
            # token not matched
            # We will trace along the fail arcs until the token matches or we
            # reach the root of the graph.
            node = state.fail
            while token not in node.next:
                node = node.fail
                if node.token == -1:  # root
                    break

            if token in node.next:
                node = node.next[token]

            # The score of the fail path
            score = node.node_score - state.local_node_score
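            # Backing off rebases the running bonus: it takes back the bonus
            # accumulated since the last completed word (state.local_node_score)
            # and grants the full root-path bonus of the fallback node
            # (node.node_score).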
        assert node is not None
        matched_score = 0
        output = node.output
        while output is not None:
            matched_score += output.node_score
            output = output.output
        return (score + matched_score, node)

    def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
        """Finalize the matching when we reach the end of the decoded
        sequence. The purpose is to subtract the bonus that was added for a
        state that is not the end of a word/phrase.

        Args:
          state:
            The given state (trie node).

        Returns:
          Return a tuple of (score, next state). If the state is the end of a
          word/phrase the score is zero, otherwise the score is that of an
          implicit fail arc to the root. The next state is always the root.
        """
        # The score of the fail arc
        score = -state.node_score
        if state.is_end:
            score = 0
        return (score, self.root)

    def draw(
        self,
        title: Optional[str] = None,
        filename: Optional[str] = "",
        symbol_table: Optional[Dict[int, str]] = None,
    ) -> "Digraph":  # noqa
        """Visualize a ContextGraph via graphviz.

        Render the ContextGraph as an image via graphviz, return the Digraph
        object, and optionally save it to `filename`.
        `filename` must have a suffix that graphviz understands, such as
        `pdf`, `svg` or `png`.

        Note:
          You need to install graphviz to use this function::

            pip install graphviz

        Args:
          title:
            Title to be displayed in the image, e.g. 'A simple FSA example'.
          filename:
            Filename to (optionally) save to, e.g. 'foo.png', 'foo.svg',
            'foo.pdf' (must have a suffix that graphviz understands).
          symbol_table:
            Map from token ids to symbols.
        Returns:
          A Digraph from graphviz.
        """

        try:
            import graphviz
        except Exception:
            print("You cannot use `draw` unless the graphviz package is installed.")
            raise

        graph_attr = {
            "rankdir": "LR",
            "size": "8.5,11",
            "center": "1",
            "orientation": "Portrait",
            "ranksep": "0.4",
            "nodesep": "0.25",
        }
        if title is not None:
            graph_attr["label"] = title

        default_node_attr = {
            "shape": "circle",
            "style": "bold",
            "fontsize": "14",
        }

        final_state_attr = {
            "shape": "doublecircle",
            "style": "bold",
            "fontsize": "14",
        }

        final_state = -1
        dot = graphviz.Digraph(name="Context Graph", graph_attr=graph_attr)

        seen = set()
        queue = deque()
        queue.append(self.root)
        # root id is always 0
        dot.node("0", label="0", **default_node_attr)
        dot.edge("0", "0", color="red")
        seen.add(0)

        while len(queue):
            current_node = queue.popleft()
            for token, node in current_node.next.items():
                if node.id not in seen:
                    node_score = f"{node.node_score:.2f}".rstrip("0").rstrip(".")
                    local_node_score = f"{node.local_node_score:.2f}".rstrip(
                        "0"
                    ).rstrip(".")
                    label = f"{node.id}/({node_score},{local_node_score})"
                    if node.is_end:
                        dot.node(str(node.id), label=label, **final_state_attr)
                    else:
                        dot.node(str(node.id), label=label, **default_node_attr)
                    seen.add(node.id)
                weight = f"{node.token_score:.2f}".rstrip("0").rstrip(".")
                label = str(token) if symbol_table is None else symbol_table[token]
                dot.edge(str(current_node.id), str(node.id), label=f"{label}/{weight}")
                dot.edge(
                    str(node.id),
                    str(node.fail.id),
                    color="red",
                )
                if node.output is not None:
                    dot.edge(
                        str(node.id),
                        str(node.output.id),
                        color="green",
                    )
                queue.append(node)

        if filename:
            _, extension = os.path.splitext(filename)
            if extension == "" or extension[0] != ".":
                raise ValueError(
                    "Filename needs to have a suffix like .png, .pdf, .svg: {}".format(
                        filename
                    )
                )

            import tempfile

            with tempfile.TemporaryDirectory() as tmp_dir:
                temp_fn = dot.render(
                    filename="temp",
                    directory=tmp_dir,
                    format=extension[1:],
                    cleanup=True,
                )

                shutil.move(temp_fn, filename)

        return dot

if __name__ == "__main__":
    contexts_str = [
        "S",
        "HE",
        "SHE",
        "SHELL",
        "HIS",
        "HERS",
        "HELLO",
        "THIS",
        "THEM",
    ]
    contexts = []
    for s in contexts_str:
        contexts.append([ord(x) for x in s])

    context_graph = ContextGraph(context_score=1)
    context_graph.build(contexts)

    symbol_table = {}
    for contexts in contexts_str:
        for s in contexts:
            symbol_table[ord(s)] = s

    context_graph.draw(
        title="Graph for: " + " / ".join(contexts_str),
        filename="context_graph.pdf",
        symbol_table=symbol_table,
    )

    queries = {
        "HEHERSHE": 14,  # "HE", "HE", "HERS", "S", "SHE", "HE"
        "HERSHE": 12,  # "HE", "HERS", "S", "SHE", "HE"
        "HISHE": 9,  # "HIS", "S", "SHE", "HE"
        "SHED": 6,  # "S", "SHE", "HE"
        "HELL": 2,  # "HE"
        "HELLO": 7,  # "HE", "HELLO"
        "DHRHISQ": 4,  # "HIS", "S"
        "THEN": 2,  # "HE"
    }
    for query, expected_score in queries.items():
        total_scores = 0
        state = context_graph.root
        for q in query:
            score, state = context_graph.forward_one_step(state, ord(q))
            total_scores += score
        score, state = context_graph.finalize(state)
        assert state.token == -1, state.token
        total_scores += score
        assert total_scores == expected_score, (
            total_scores,
            expected_score,
            query,
        )
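The self-test above is a quick way to sanity-check the scoring: running python icefall/context_graph.py from the repo root executes all the queries (the draw step additionally requires the graphviz Python package and binary to be installed).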