Add Context biasing (#1038)
* Add context biasing for librispeech
* Add context biasing for wenetspeech
* fix bugs
* Implement Aho-Corasick context graph
* fix some bugs
* Fixes to forward_one_step; add draw to context graph
* add output arc; fix black
* Fix wenetspeech tokenizer
* Minor fixes to the decode.py
parent ca60ced213
commit ba257efbcd
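In short, this commit lets modified_beam_search boost user-supplied words/phrases during decoding. A minimal sketch of the API it introduces (ContextGraph with build/forward_one_step/finalize, defined in icefall/context_graph.py further down); the token ids and score below are illustrative only — in the recipes they come from sp.encode or CharCtcTrainingGraphCompiler.texts_to_ids:

    from icefall.context_graph import ContextGraph

    graph = ContextGraph(context_score=2.0)  # per-token bonus
    graph.build([[13, 5, 9], [7, 2]])        # one token list per biasing word/phrase

    state = graph.root
    total_bonus = 0.0
    for token in [13, 5, 9]:                 # tokens of a decoded hypothesis
        bonus, state = graph.forward_one_step(state, token)
        total_bonus += bonus                 # added to the hypothesis log-prob
    bonus, state = graph.finalize(state)     # retracts bonuses of unfinished matches
    total_bonus += bonus
    assert total_bonus == 6.0                # 3 matched tokens x context_score 2.0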
@@ -58,6 +58,7 @@ Usage:
 import argparse
 import logging
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -76,6 +77,8 @@ from beam_search import (
 )
 from train import add_model_arguments, get_params, get_transducer_model

+from icefall import ContextGraph, LmScorer, NgramLm
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -211,6 +214,26 @@ def get_parser():
        Used only when --decoding_method is greedy_search""",
    )

+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=2,
+        help="""
+        The bonus score of each token for the context biasing words/phrases.
+        Used only when --decoding_method is modified_beam_search.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="""
+        The path of the context biasing lists, one word/phrase each line
+        Used only when --decoding_method is modified_beam_search.
+        """,
+    )
+
    add_model_arguments(parser)

    return parser
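Since --context-score is a per-token bonus, a phrase's total boost scales with its length, and a phrase that is only partially matched ends up with no net boost because finalize subtracts what was provisionally added. A hedged check of that property (token ids invented):

    from icefall.context_graph import ContextGraph

    graph = ContextGraph(context_score=2.0)
    graph.build([[13, 5, 9]])  # a single 3-token phrase from --context-file

    total, state = 0.0, graph.root
    for token in [13, 5]:      # only a prefix of the phrase is decoded
        bonus, state = graph.forward_one_step(state, token)
        total += bonus         # +2.0 each, added provisionally
    bonus, _ = graph.finalize(state)
    assert total + bonus == 0.0  # the provisional +4.0 is taken back at the end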
@@ -222,6 +245,7 @@ def decode_one_batch(
    token_table: k2.SymbolTable,
    batch: dict,
    decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:
@@ -285,6 +309,7 @@ def decode_one_batch(
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
+            context_graph=context_graph,
        )
    else:
        hyp_tokens = []
@@ -324,7 +349,12 @@ def decode_one_batch(
            ): hyps
        }
    else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        key = f"beam_size_{params.beam_size}"
+        if params.has_contexts:
+            key += f"-context-score-{params.context_score}"
+        else:
+            key += "-no-context-words"
+        return {key: hyps}


def decode_dataset(
@@ -333,6 +363,7 @@ def decode_dataset(
    model: nn.Module,
    token_table: k2.SymbolTable,
    decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.

@@ -377,6 +408,7 @@ def decode_dataset(
            model=model,
            token_table=token_table,
            decoding_graph=decoding_graph,
+            context_graph=context_graph,
            batch=batch,
        )

@@ -407,16 +439,17 @@ def save_results(
    for key, results in results_dict.items():
        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
        results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
+        # we compute CER for aishell dataset.
+        results_char = []
+        for res in results:
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+
+        store_transcripts(filename=recog_path, texts=results_char)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        # we compute CER for aishell dataset.
-        results_char = []
-        for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results_char, enable_log=True
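The conversion above turns word-level transcripts into character-level ones so that write_error_stats effectively reports CER rather than WER. A small illustration of the transform (the sample utterance is invented):

    # (cut_id, ref_words, hyp_words) -> (cut_id, ref_chars, hyp_chars)
    res = ("cut-001", ["甚至", "出现"], ["甚至", "初现"])
    res_char = (res[0], list("".join(res[1])), list("".join(res[2])))
    assert res_char == ("cut-001", ["甚", "至", "出", "现"], ["甚", "至", "初", "现"])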
@@ -457,6 +490,12 @@ def main():
        "fast_beam_search",
        "modified_beam_search",
    )

+    if os.path.exists(params.context_file):
+        params.has_contexts = True
+    else:
+        params.has_contexts = False
+
    params.res_dir = params.exp_dir / params.decoding_method

    if params.iter > 0:
@@ -470,6 +509,10 @@ def main():
        params.suffix += f"-max-states-{params.max_states}"
    elif "beam_search" in params.decoding_method:
        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        if params.has_contexts:
+            params.suffix += f"-context-score-{params.context_score}"
+        else:
+            params.suffix += "-no-contexts-words"
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -490,6 +533,11 @@ def main():
    params.blank_id = 0
    params.vocab_size = max(lexicon.tokens) + 1

+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )
+
    logging.info(params)

    logging.info("About to create model")
@@ -586,6 +634,19 @@ def main():
    else:
        decoding_graph = None

+    if params.decoding_method == "modified_beam_search":
+        if os.path.exists(params.context_file):
+            contexts_text = []
+            for line in open(params.context_file).readlines():
+                contexts_text.append(line.strip())
+            contexts = graph_compiler.texts_to_ids(contexts_text)
+            context_graph = ContextGraph(params.context_score)
+            context_graph.build(contexts)
+        else:
+            context_graph = None
+    else:
+        context_graph = None
+
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

@@ -608,6 +669,7 @@ def main():
        model=model,
        token_table=lexicon.token_table,
        decoding_graph=decoding_graph,
+        context_graph=context_graph,
    )

    save_results(
@@ -24,7 +24,7 @@ import sentencepiece as spm
 import torch
 from model import Transducer

-from icefall import NgramLm, NgramLmStateCost
+from icefall import ContextGraph, ContextState, NgramLm, NgramLmStateCost
 from icefall.decode import Nbest, one_best_decoding
 from icefall.lm_wrapper import LmScorer
 from icefall.rnn_lm.model import RnnLmModel
@@ -765,6 +765,9 @@ class Hypothesis:
    # N-gram LM state
    state_cost: Optional[NgramLmStateCost] = None

+    # Context graph state
+    context_state: Optional[ContextState] = None
+
    @property
    def key(self) -> str:
        """Return a string representation of self.ys"""
@@ -917,6 +920,7 @@ def modified_beam_search(
    model: Transducer,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
+    context_graph: Optional[ContextGraph] = None,
    beam: int = 4,
    temperature: float = 1.0,
    return_timestamps: bool = False,
@@ -968,6 +972,7 @@ def modified_beam_search(
            Hypothesis(
                ys=[blank_id] * context_size,
                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+                context_state=None if context_graph is None else context_graph.root,
                timestamp=[],
            )
        )
@@ -990,6 +995,7 @@ def modified_beam_search(
        hyps_shape = get_hyps_shape(B).to(device)
+
        A = [list(b) for b in B]

        B = [HypothesisList() for _ in range(batch_size)]

        ys_log_probs = torch.cat(
@@ -1047,21 +1053,51 @@ def modified_beam_search(
            for k in range(len(topk_hyp_indexes)):
                hyp_idx = topk_hyp_indexes[k]
                hyp = A[i][hyp_idx]

                new_ys = hyp.ys[:]
                new_token = topk_token_indexes[k]
                new_timestamp = hyp.timestamp[:]
+                context_score = 0
+                new_context_state = None if context_graph is None else hyp.context_state
                if new_token not in (blank_id, unk_id):
                    new_ys.append(new_token)
                    new_timestamp.append(t)
+                    if context_graph is not None:
+                        (
+                            context_score,
+                            new_context_state,
+                        ) = context_graph.forward_one_step(hyp.context_state, new_token)
+
+                new_log_prob = topk_log_probs[k] + context_score

-                new_log_prob = topk_log_probs[k]
                new_hyp = Hypothesis(
-                    ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
+                    ys=new_ys,
+                    log_prob=new_log_prob,
+                    timestamp=new_timestamp,
+                    context_state=new_context_state,
                )
                B[i].add(new_hyp)

    B = B + finalized_B

+    # finalize context_state, if the matched contexts do not reach final state
+    # we need to add the score on the corresponding backoff arc
+    if context_graph is not None:
+        finalized_B = [HypothesisList() for _ in range(len(B))]
+        for i, hyps in enumerate(B):
+            for hyp in list(hyps):
+                context_score, new_context_state = context_graph.finalize(
+                    hyp.context_state
+                )
+                finalized_B[i].add(
+                    Hypothesis(
+                        ys=hyp.ys,
+                        log_prob=hyp.log_prob + context_score,
+                        timestamp=hyp.timestamp,
+                        context_state=new_context_state,
+                    )
+                )
+        B = finalized_B
+
    best_hyps = [b.get_most_probable(length_norm=True) for b in B]

    sorted_ans = [h.ys[context_size:] for h in best_hyps]
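Putting the two parts of this hunk together: each accepted non-blank token advances the hypothesis one step through the context graph and the returned bonus is folded into its log-prob; when decoding ends, finalize adds the (negative) backoff score for matches that stopped mid-phrase. A condensed, hedged view of that per-hypothesis bookkeeping (not the actual beam-search code; the names are stand-ins):

    def rescore_with_contexts(tokens, acoustic_log_probs, graph):
        log_prob, state = 0.0, graph.root
        for token, ac in zip(tokens, acoustic_log_probs):
            bonus, state = graph.forward_one_step(state, token)
            log_prob += ac + bonus  # mirrors: new_log_prob = topk_log_probs[k] + context_score
        bonus, state = graph.finalize(state)  # backoff for unfinished matches
        return log_prob + bonus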
@@ -125,6 +125,7 @@ For example:
 import argparse
 import logging
 import math
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -146,6 +147,7 @@ from beam_search import (
 )
 from train import add_model_arguments, get_params, get_transducer_model

+from icefall import ContextGraph
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -353,6 +355,27 @@ def get_parser():
        Used only when the decoding method is fast_beam_search_nbest,
        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
    )

+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=2,
+        help="""
+        The bonus score of each token for the context biasing words/phrases.
+        Used only when --decoding_method is modified_beam_search.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="""
+        The path of the context biasing lists, one word/phrase each line
+        Used only when --decoding_method is modified_beam_search.
+        """,
+    )
+
    add_model_arguments(parser)

    return parser
@@ -365,6 +388,7 @@ def decode_one_batch(
    batch: dict,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:
@@ -494,6 +518,7 @@ def decode_one_batch(
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
+            context_graph=context_graph,
            return_timestamps=True,
        )
    else:
@@ -548,7 +573,12 @@ def decode_one_batch(

        return {key: (hyps, timestamps)}
    else:
-        return {f"beam_size_{params.beam_size}": (hyps, timestamps)}
+        key = f"beam_size_{params.beam_size}"
+        if params.has_contexts:
+            key += f"-context-score-{params.context_score}"
+        else:
+            key += "-no-context-words"
+        return {key: (hyps, timestamps)}


def decode_dataset(
@@ -558,6 +588,7 @@ def decode_dataset(
    sp: spm.SentencePieceProcessor,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]:
    """Decode dataset.

@@ -622,6 +653,7 @@ def decode_dataset(
                decoding_graph=decoding_graph,
                word_table=word_table,
                batch=batch,
+                context_graph=context_graph,
            )

            for name, (hyps, timestamps_hyp) in hyps_dict.items():
@@ -728,6 +760,12 @@ def main():
        "fast_beam_search_nbest_oracle",
        "modified_beam_search",
    )

+    if os.path.exists(params.context_file):
+        params.has_contexts = True
+    else:
+        params.has_contexts = False
+
    params.res_dir = params.exp_dir / params.decoding_method

    if params.iter > 0:
@@ -750,6 +788,10 @@ def main():
        params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
    elif "beam_search" in params.decoding_method:
        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        if params.has_contexts:
+            params.suffix += f"-context-score-{params.context_score}"
+        else:
+            params.suffix += "-no-context-words"
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -881,6 +923,18 @@ def main():
        decoding_graph = None
        word_table = None

+    if params.decoding_method == "modified_beam_search":
+        if os.path.exists(params.context_file):
+            contexts = []
+            for line in open(params.context_file).readlines():
+                contexts.append(line.strip())
+            context_graph = ContextGraph(params.context_score)
+            context_graph.build(sp.encode(contexts))
+        else:
+            context_graph = None
+    else:
+        context_graph = None
+
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

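For this BPE-based recipe, sp.encode on a list of strings returns one token-id list per string, which is exactly the shape ContextGraph.build expects. For example (the words and the model path are placeholders):

    import sentencepiece as spm
    from icefall import ContextGraph

    sp = spm.SentencePieceProcessor()
    sp.load("data/lang_bpe_500/bpe.model")   # placeholder path

    contexts = ["NIGHTINGALE", "LOQUACIOUS"]  # lines of --context-file
    context_graph = ContextGraph(context_score=2.0)
    context_graph.build(sp.encode(contexts))  # List[str] -> List[List[int]]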
@@ -905,6 +959,7 @@ def main():
        sp=sp,
        word_table=word_table,
        decoding_graph=decoding_graph,
+        context_graph=context_graph,
    )

    save_results(
@@ -106,7 +106,7 @@ class WenetSpeechAsrDataModule:
        group.add_argument(
            "--num-buckets",
            type=int,
-            default=300,
+            default=30,
            help="The number of buckets for the DynamicBucketingSampler"
            "(you might want to increase it for larger datasets).",
        )
@@ -364,7 +364,7 @@ class WenetSpeechAsrDataModule:
        return valid_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
-        logging.debug("About to create test dataset")
+        logging.info("About to create test dataset")
        test = K2SpeechRecognitionDataset(
            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
            if self.args.on_the_fly_feats
@@ -92,7 +92,7 @@ When training with the L subset, the streaming usage:
    --causal-convolution 1 \
    --decode-chunk-size 16 \
    --left-context 64

(4) modified beam search with RNNLM shallow fusion
./pruned_transducer_stateless5/decode.py \
    --epoch 35 \
@@ -112,8 +112,10 @@ When training with the L subset, the streaming usage:


 import argparse
+import glob
 import logging
 import math
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -133,7 +135,8 @@ from beam_search import (
 )
 from train import add_model_arguments, get_params, get_transducer_model

-from icefall import LmScorer, NgramLm
+from icefall import ContextGraph, LmScorer, NgramLm
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -307,6 +310,26 @@ def get_parser():
        help="left context can be seen during decoding (in frames after subsampling)",
    )

+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=2,
+        help="""
+        The bonus score of each token for the context biasing words/phrases.
+        Used only when --decoding_method is modified_beam_search.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="""
+        The path of the context biasing lists, one word/phrase each line
+        Used only when --decoding_method is modified_beam_search.
+        """,
+    )
+
    parser.add_argument(
        "--use-shallow-fusion",
        type=str2bool,
@@ -362,6 +385,7 @@ def decode_one_batch(
    lexicon: Lexicon,
    batch: dict,
    decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
    ngram_lm: Optional[NgramLm] = None,
    ngram_lm_scale: float = 1.0,
    LM: Optional[LmScorer] = None,
@@ -402,14 +426,13 @@ def decode_one_batch(
    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

-    feature_lens += params.left_context
-    feature = torch.nn.functional.pad(
-        feature,
-        pad=(0, 0, 0, params.left_context),
-        value=LOG_EPS,
-    )
-
    if params.simulate_streaming:
+        feature_lens += params.left_context
+        feature = torch.nn.functional.pad(
+            feature,
+            pad=(0, 0, 0, params.left_context),
+            value=LOG_EPS,
+        )
        encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
            x=feature,
            x_lens=feature_lens,
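Moving the padding under params.simulate_streaming leaves non-streaming decoding untouched; the pad call itself extends the time axis on the right by left_context frames, since F.pad consumes the pad tuple from the last dimension backwards. A quick check of that behavior (the fill value is a stand-in for LOG_EPS):

    import torch
    import torch.nn.functional as F

    feature = torch.zeros(2, 100, 80)    # (batch, time, num_mel_bins)
    left_context = 64
    padded = F.pad(feature, pad=(0, 0, 0, left_context), value=-23.0)
    assert padded.shape == (2, 164, 80)  # feature dim untouched, time padded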
@@ -448,6 +471,7 @@ def decode_one_batch(
            encoder_out=encoder_out,
            beam=params.beam_size,
            encoder_out_lens=encoder_out_lens,
+            context_graph=context_graph,
        )
        for i in range(encoder_out.size(0)):
            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
@@ -509,7 +533,12 @@ def decode_one_batch(
            ): hyps
        }
    else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        key = f"beam_size_{params.beam_size}"
+        if params.has_contexts:
+            key += f"-context-score-{params.context_score}"
+        else:
+            key += "-no-context-words"
+        return {key: hyps}


def decode_dataset(
@@ -518,6 +547,7 @@ def decode_dataset(
    model: nn.Module,
    lexicon: Lexicon,
    decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
    ngram_lm: Optional[NgramLm] = None,
    ngram_lm_scale: float = 1.0,
    LM: Optional[LmScorer] = None,
@@ -567,6 +597,7 @@ def decode_dataset(
                lexicon=lexicon,
                decoding_graph=decoding_graph,
                batch=batch,
+                context_graph=context_graph,
                ngram_lm=ngram_lm,
                ngram_lm_scale=ngram_lm_scale,
                LM=LM,
@@ -646,6 +677,12 @@ def main():
        "modified_beam_search_lm_shallow_fusion",
        "modified_beam_search_LODR",
    )

+    if os.path.exists(params.context_file):
+        params.has_contexts = True
+    else:
+        params.has_contexts = False
+
    params.res_dir = params.exp_dir / params.decoding_method

    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
@@ -655,6 +692,10 @@ def main():
        params.suffix += f"-max-states-{params.max_states}"
    elif "beam_search" in params.decoding_method:
        params.suffix += f"-beam-{params.beam_size}"
+        if params.has_contexts:
+            params.suffix += f"-context-score-{params.context_score}"
+        else:
+            params.suffix += "-no-contexts-words"
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -684,11 +725,15 @@ def main():

    logging.info(f"Device: {device}")

-    # import pdb; pdb.set_trace()
    lexicon = Lexicon(params.lang_dir)
    params.blank_id = lexicon.token_table["<blk>"]
    params.vocab_size = max(lexicon.tokens) + 1

+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )
+
    if params.simulate_streaming:
        assert (
            params.causal_convolution
@@ -816,6 +861,19 @@ def main():
    else:
        decoding_graph = None

+    if params.decoding_method == "modified_beam_search":
+        if os.path.exists(params.context_file):
+            contexts_text = []
+            for line in open(params.context_file).readlines():
+                contexts_text.append(line.strip())
+            contexts = graph_compiler.texts_to_ids(contexts_text)
+            context_graph = ContextGraph(params.context_score)
+            context_graph.build(contexts)
+        else:
+            context_graph = None
+    else:
+        context_graph = None
+
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

@@ -833,15 +891,16 @@ def main():
    test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)

    test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
-    test_dl = [dev_dl, test_net_dl, test_meeting_dl]
+    test_dls = [dev_dl, test_net_dl, test_meeting_dl]

-    for test_set, test_dl in zip(test_sets, test_dl):
+    for test_set, test_dl in zip(test_sets, test_dls):
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
            lexicon=lexicon,
            decoding_graph=decoding_graph,
+            context_graph=context_graph,
            ngram_lm=ngram_lm,
            ngram_lm_scale=ngram_lm_scale,
            LM=LM,
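The test_dl to test_dls rename fixes a shadowing pattern: the old code reused test_dl for both the list and the loop variable. It happened to work, because zip() captures the list before the name is rebound, but the rename makes the intent clear. A minimal illustration of the old behavior (the strings stand in for DataLoaders):

    test_sets = ["DEV", "TEST_NET"]
    test_dl = ["dl-dev", "dl-net"]  # stand-ins for DataLoaders

    # Works, but `test_dl` is rebound on the first iteration:
    for test_set, test_dl in zip(test_sets, test_dl):
        pass
    assert test_dl == "dl-net"      # the list is gone after the loop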
@@ -23,6 +23,8 @@ from .checkpoint import (
    save_checkpoint_with_global_batch_idx,
)

+from .context_graph import ContextGraph, ContextState
+
from .decode import (
    get_lattice,
    nbest_decoding,
icefall/context_graph.py (new file, 412 lines)
@@ -0,0 +1,412 @@
# Copyright      2023  Xiaomi Corp.        (authors: Wei Kang)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
from collections import deque
from typing import Dict, List, Optional, Tuple


class ContextState:
    """The state in ContextGraph"""

    def __init__(
        self,
        id: int,
        token: int,
        token_score: float,
        node_score: float,
        local_node_score: float,
        is_end: bool,
    ):
        """Create a ContextState.

        Args:
          id:
            The node id, only for visualization now. A node is in [0, graph.num_nodes).
            The id of the root node is always 0.
          token:
            The token id.
          token_score:
            The bonus for each token during decoding, which will hopefully
            boost the token up to survive beam search.
          node_score:
            The accumulated bonus from the root of the graph to the current node;
            it will be used to calculate the score of the fail arc.
          local_node_score:
            The accumulated bonus from the last ``end_node`` (node with is_end true)
            to the current node; it will be used to calculate the score of the fail arc.
            Note: The local_node_score of an ``end_node`` is 0.
          is_end:
            True if the current token is the end of a context.
        """
        self.id = id
        self.token = token
        self.token_score = token_score
        self.node_score = node_score
        self.local_node_score = local_node_score
        self.is_end = is_end
        self.next = {}
        self.fail = None
        self.output = None


class ContextGraph:
    """The ContextGraph is modified from Aho-Corasick; it is mainly
    a trie with a fail arc for each node.
    See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for more details
    of the Aho-Corasick algorithm.

    A ContextGraph contains the words/phrases whose scores we expect to boost
    during decoding. If a substring of a decoded sequence matches a word/phrase
    in the ContextGraph, we give the decoded sequence a bonus to make it survive
    beam search.
    """

    def __init__(self, context_score: float):
        """Initialize a ContextGraph with the given ``context_score``.

        A root node will be created (**NOTE:** the token of root is hardcoded to -1).

        Args:
          context_score:
            The bonus score for each token (note: NOT for each word/phrase, so a longer
            word/phrase will have a larger bonus score; it has to be fully matched
            though).
        """
        self.context_score = context_score
        self.num_nodes = 0
        self.root = ContextState(
            id=self.num_nodes,
            token=-1,
            token_score=0,
            node_score=0,
            local_node_score=0,
            is_end=False,
        )
        self.root.fail = self.root

    def _fill_fail_output(self):
        """This function fills the fail arc for each trie node; it can be computed
        in linear time by performing a breadth-first search starting from the root.
        See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for the
        details of the algorithm.
        """
        queue = deque()
        for token, node in self.root.next.items():
            node.fail = self.root
            queue.append(node)
        while queue:
            current_node = queue.popleft()
            for token, node in current_node.next.items():
                fail = current_node.fail
                if token in fail.next:
                    fail = fail.next[token]
                else:
                    fail = fail.fail
                    while token not in fail.next:
                        fail = fail.fail
                        if fail.token == -1:  # root
                            break
                    if token in fail.next:
                        fail = fail.next[token]
                node.fail = fail
                # fill the output arc
                output = node.fail
                while not output.is_end:
                    output = output.fail
                    if output.token == -1:  # root
                        output = None
                        break
                node.output = output
                queue.append(node)

    def build(self, token_ids: List[List[int]]):
        """Build the ContextGraph from a list of token lists.
        It first builds a trie from the given token lists, then fills the fail arc
        for each trie node.

        See https://en.wikipedia.org/wiki/Trie for how to build a trie.

        Args:
          token_ids:
            The given token lists to build the ContextGraph; it is a list of token
            lists, and each token list contains the token ids of a word/phrase. A
            token id could be the id of a char (modeling with single Chinese chars)
            or the id of a BPE piece (modeling with BPEs).
        """
        for tokens in token_ids:
            node = self.root
            for i, token in enumerate(tokens):
                if token not in node.next:
                    self.num_nodes += 1
                    is_end = i == len(tokens) - 1
                    node.next[token] = ContextState(
                        id=self.num_nodes,
                        token=token,
                        token_score=self.context_score,
                        node_score=node.node_score + self.context_score,
                        local_node_score=0
                        if is_end
                        else (node.local_node_score + self.context_score),
                        is_end=is_end,
                    )
                node = node.next[token]
        self._fill_fail_output()

    def forward_one_step(
        self, state: ContextState, token: int
    ) -> Tuple[float, ContextState]:
        """Search the graph with the given state and token.

        Args:
          state:
            The given state (trie node) to start from.
          token:
            The given token.

        Returns:
          Return a tuple of score and next state.
        """
        node = None
        score = 0
        # token matched
        if token in state.next:
            node = state.next[token]
            score = node.token_score
            if state.is_end:
                score += state.node_score
        else:
            # token not matched
            # We will trace along the fail arc until it matches the token or reaches
            # the root of the graph.
            node = state.fail
            while token not in node.next:
                node = node.fail
                if node.token == -1:  # root
                    break

            if token in node.next:
                node = node.next[token]

            # The score of the fail path
            score = node.node_score - state.local_node_score
        assert node is not None
        matched_score = 0
        output = node.output
        while output is not None:
            matched_score += output.node_score
            output = output.output
        return (score + matched_score, node)

    def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
        """When reaching the end of the decoded sequence, we need to finalize
        the matching; the purpose is to subtract the bonus score added for a
        state that is not the end of a word/phrase.

        Args:
          state:
            The given state (trie node).

        Returns:
          Return a tuple of score and next state. If the state is the end of a
          word/phrase the score is zero, otherwise the score is that of an implicit
          fail arc to root. The next state is always root.
        """
        # The score of the fail arc
        score = -state.node_score
        if state.is_end:
            score = 0
        return (score, self.root)

    def draw(
        self,
        title: Optional[str] = None,
        filename: Optional[str] = "",
        symbol_table: Optional[Dict[int, str]] = None,
    ) -> "Digraph":  # noqa
        """Visualize a ContextGraph via graphviz.

        Render the ContextGraph as an image via graphviz and return the Digraph
        object; optionally save it to file `filename`.
        `filename` must have a suffix that graphviz understands, such as
        `pdf`, `svg` or `png`.

        Note:
          You need to install graphviz to use this function::

            pip install graphviz

        Args:
          title:
            Title to be displayed in the image, e.g. 'A simple FSA example'
          filename:
            Filename to (optionally) save to, e.g. 'foo.png', 'foo.svg',
            'foo.pdf' (must have a suffix that graphviz understands).
          symbol_table:
            Map the token ids to symbols.
        Returns:
          A Digraph from graphviz.
        """

        try:
            import graphviz
        except Exception:
            print("You cannot use `to_dot` unless the graphviz package is installed.")
            raise

        graph_attr = {
            "rankdir": "LR",
            "size": "8.5,11",
            "center": "1",
            "orientation": "Portrait",
            "ranksep": "0.4",
            "nodesep": "0.25",
        }
        if title is not None:
            graph_attr["label"] = title

        default_node_attr = {
            "shape": "circle",
            "style": "bold",
            "fontsize": "14",
        }

        final_state_attr = {
            "shape": "doublecircle",
            "style": "bold",
            "fontsize": "14",
        }

        final_state = -1
        dot = graphviz.Digraph(name="Context Graph", graph_attr=graph_attr)

        seen = set()
        queue = deque()
        queue.append(self.root)
        # root id is always 0
        dot.node("0", label="0", **default_node_attr)
        dot.edge("0", "0", color="red")
        seen.add(0)

        while len(queue):
            current_node = queue.popleft()
            for token, node in current_node.next.items():
                if node.id not in seen:
                    node_score = f"{node.node_score:.2f}".rstrip("0").rstrip(".")
                    local_node_score = f"{node.local_node_score:.2f}".rstrip(
                        "0"
                    ).rstrip(".")
                    label = f"{node.id}/({node_score},{local_node_score})"
                    if node.is_end:
                        dot.node(str(node.id), label=label, **final_state_attr)
                    else:
                        dot.node(str(node.id), label=label, **default_node_attr)
                    seen.add(node.id)
                weight = f"{node.token_score:.2f}".rstrip("0").rstrip(".")
                label = str(token) if symbol_table is None else symbol_table[token]
                dot.edge(str(current_node.id), str(node.id), label=f"{label}/{weight}")
                dot.edge(
                    str(node.id),
                    str(node.fail.id),
                    color="red",
                )
                if node.output is not None:
                    dot.edge(
                        str(node.id),
                        str(node.output.id),
                        color="green",
                    )
                queue.append(node)

        if filename:
            _, extension = os.path.splitext(filename)
            if extension == "" or extension[0] != ".":
                raise ValueError(
                    "Filename needs to have a suffix like .png, .pdf, .svg: {}".format(
                        filename
                    )
                )

            import tempfile

            with tempfile.TemporaryDirectory() as tmp_dir:
                temp_fn = dot.render(
                    filename="temp",
                    directory=tmp_dir,
                    format=extension[1:],
                    cleanup=True,
                )

                shutil.move(temp_fn, filename)

        return dot


if __name__ == "__main__":
    contexts_str = [
        "S",
        "HE",
        "SHE",
        "SHELL",
        "HIS",
        "HERS",
        "HELLO",
        "THIS",
        "THEM",
    ]
    contexts = []
    for s in contexts_str:
        contexts.append([ord(x) for x in s])

    context_graph = ContextGraph(context_score=1)
    context_graph.build(contexts)

    symbol_table = {}
    for contexts in contexts_str:
        for s in contexts:
            symbol_table[ord(s)] = s

    context_graph.draw(
        title="Graph for: " + " / ".join(contexts_str),
        filename="context_graph.pdf",
        symbol_table=symbol_table,
    )

    queries = {
        "HEHERSHE": 14,  # "HE", "HE", "HERS", "S", "SHE", "HE"
        "HERSHE": 12,  # "HE", "HERS", "S", "SHE", "HE"
        "HISHE": 9,  # "HIS", "S", "SHE", "HE"
        "SHED": 6,  # "S", "SHE", "HE"
        "HELL": 2,  # "HE"
        "HELLO": 7,  # "HE", "HELLO"
        "DHRHISQ": 4,  # "HIS", "S"
        "THEN": 2,  # "HE"
    }
    for query, expected_score in queries.items():
        total_scores = 0
        state = context_graph.root
        for q in query:
            score, state = context_graph.forward_one_step(state, ord(q))
            total_scores += score
        score, state = context_graph.finalize(state)
        assert state.token == -1, state.token
        total_scores += score
        assert total_scores == expected_score, (
            total_scores,
            expected_score,
            query,
        )
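To see where the self-test scores come from, it may help to walk one query by hand. This sketch re-runs the "HELL" case from `queries` above and annotates each step (assuming the demo graph with context_score=1):

    from icefall.context_graph import ContextGraph

    words = ["S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"]
    graph = ContextGraph(context_score=1)
    graph.build([[ord(c) for c in w] for w in words])

    total, state = 0, graph.root
    for ch in "HELL":
        score, state = graph.forward_one_step(state, ord(ch))
        total += score
    # Running totals: H -> 1, E -> 2 ("HE" completed), L -> 5 (+1 token bonus
    # plus the +2 banked at the end node "HE"), L -> 6 (inside "HELL(O)").
    score, state = graph.finalize(state)  # -4: retract the unfinished "HELL" path
    total += score
    assert total == 2  # net credit: just the two tokens of "HE"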