Add Context biasing (#1038)

* Add context biasing for librispeech

* Add context biasing for wenetspeech

* fix bugs

* Implement Aho-Corasick context graph

* fix some bugs

* Fixes to forward_one_step; add draw to context graph

* add output arc; fix black

* Fix wenetspeech tokenizer

* Minor fixes to the decode.py
Wei Kang 2023-06-03 21:28:49 +08:00 committed by GitHub
parent ca60ced213
commit ba257efbcd
7 changed files with 652 additions and 26 deletions
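In short, decoding gains two new flags: --context-file (a plain-text list of biasing words/phrases, one per line) and --context-score (the per-token bonus, default 2). In outline, the new icefall.ContextGraph API added below is used like this (a minimal sketch distilled from the diffs in this commit; the token ids are made up for illustration):

    from icefall import ContextGraph

    # Tokenized biasing phrases: BPE ids for librispeech, char ids for
    # wenetspeech (the ids below are invented for this sketch).
    phrases = [[25, 103, 7], [512, 18]]

    # Every matched token earns a bonus of +2.0 (the --context-score default).
    context_graph = ContextGraph(context_score=2.0)
    context_graph.build(phrases)

    # modified_beam_search keeps one context state per hypothesis; each
    # emitted token advances the state and collects a (possibly negative)
    # score that is added to the hypothesis log-prob.
    score, state = context_graph.forward_one_step(context_graph.root, 25)

    # When decoding ends, finalize() takes back the bonus of any phrase
    # that was only partially matched.
    score, state = context_graph.finalize(state)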


@@ -58,6 +58,7 @@ Usage:
 import argparse
 import logging
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple

@@ -76,6 +77,8 @@ from beam_search import (
 )
 from train import add_model_arguments, get_params, get_transducer_model

+from icefall import ContextGraph, LmScorer, NgramLm
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,

@@ -211,6 +214,26 @@ def get_parser():
         Used only when --decoding_method is greedy_search""",
     )

+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=2,
+        help="""
+        The bonus score for each token of the context biasing words/phrases.
+        Used only when --decoding_method is modified_beam_search.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="""
+        The path of the context biasing lists, one word/phrase per line.
+        Used only when --decoding_method is modified_beam_search.
+        """,
+    )
+
     add_model_arguments(parser)

     return parser

@@ -222,6 +245,7 @@ def decode_one_batch(
     token_table: k2.SymbolTable,
     batch: dict,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:

@@ -285,6 +309,7 @@
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
+            context_graph=context_graph,
         )
     else:
         hyp_tokens = []

@@ -324,7 +349,12 @@
             ): hyps
         }
     else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        key = f"beam_size_{params.beam_size}"
+        if params.has_contexts:
+            key += f"-context-score-{params.context_score}"
+        else:
+            key += "-no-context-words"
+        return {key: hyps}


 def decode_dataset(

@@ -333,6 +363,7 @@
     model: nn.Module,
     token_table: k2.SymbolTable,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.

@@ -377,6 +408,7 @@
             model=model,
             token_table=token_table,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
             batch=batch,
         )

@@ -407,16 +439,17 @@ def save_results(
     for key, results in results_dict.items():
         recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
+        # we compute CER for aishell dataset.
+        results_char = []
+        for res in results:
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+        store_transcripts(filename=recog_path, texts=results_char)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        # we compute CER for aishell dataset.
-        results_char = []
-        for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True

@@ -457,6 +490,12 @@ def main():
         "fast_beam_search",
         "modified_beam_search",
     )
+
+    if os.path.exists(params.context_file):
+        params.has_contexts = True
+    else:
+        params.has_contexts = False
+
     params.res_dir = params.exp_dir / params.decoding_method

     if params.iter > 0:

@@ -470,6 +509,10 @@
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        if params.has_contexts:
+            params.suffix += f"-context-score-{params.context_score}"
+        else:
+            params.suffix += "-no-context-words"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

@@ -490,6 +533,11 @@
     params.blank_id = 0
     params.vocab_size = max(lexicon.tokens) + 1

+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )
+
     logging.info(params)

     logging.info("About to create model")

@@ -586,6 +634,19 @@
     else:
         decoding_graph = None

+    if params.decoding_method == "modified_beam_search":
+        if os.path.exists(params.context_file):
+            contexts_text = []
+            for line in open(params.context_file).readlines():
+                contexts_text.append(line.strip())
+            contexts = graph_compiler.texts_to_ids(contexts_text)
+            context_graph = ContextGraph(params.context_score)
+            context_graph.build(contexts)
+        else:
+            context_graph = None
+    else:
+        context_graph = None
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

@@ -608,6 +669,7 @@
             model=model,
             token_table=lexicon.token_table,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
         )

         save_results(


@@ -24,7 +24,7 @@ import sentencepiece as spm
 import torch
 from model import Transducer

-from icefall import NgramLm, NgramLmStateCost
+from icefall import ContextGraph, ContextState, NgramLm, NgramLmStateCost
 from icefall.decode import Nbest, one_best_decoding
 from icefall.lm_wrapper import LmScorer
 from icefall.rnn_lm.model import RnnLmModel

@@ -765,6 +765,9 @@ class Hypothesis:
     # N-gram LM state
     state_cost: Optional[NgramLmStateCost] = None

+    # Context graph state
+    context_state: Optional[ContextState] = None
+
     @property
     def key(self) -> str:
         """Return a string representation of self.ys"""

@@ -917,6 +920,7 @@ def modified_beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
     encoder_out_lens: torch.Tensor,
+    context_graph: Optional[ContextGraph] = None,
     beam: int = 4,
     temperature: float = 1.0,
     return_timestamps: bool = False,

@@ -968,6 +972,7 @@
             Hypothesis(
                 ys=[blank_id] * context_size,
                 log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+                context_state=None if context_graph is None else context_graph.root,
                 timestamp=[],
             )
         )

@@ -990,6 +995,7 @@
         hyps_shape = get_hyps_shape(B).to(device)

         A = [list(b) for b in B]
+
         B = [HypothesisList() for _ in range(batch_size)]

         ys_log_probs = torch.cat(

@@ -1047,21 +1053,51 @@
             for k in range(len(topk_hyp_indexes)):
                 hyp_idx = topk_hyp_indexes[k]
                 hyp = A[i][hyp_idx]

                 new_ys = hyp.ys[:]
                 new_token = topk_token_indexes[k]
                 new_timestamp = hyp.timestamp[:]
+                context_score = 0
+                new_context_state = None if context_graph is None else hyp.context_state
                 if new_token not in (blank_id, unk_id):
                     new_ys.append(new_token)
                     new_timestamp.append(t)
+                    if context_graph is not None:
+                        (
+                            context_score,
+                            new_context_state,
+                        ) = context_graph.forward_one_step(hyp.context_state, new_token)

-                new_log_prob = topk_log_probs[k]
+                new_log_prob = topk_log_probs[k] + context_score

                 new_hyp = Hypothesis(
-                    ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
+                    ys=new_ys,
+                    log_prob=new_log_prob,
+                    timestamp=new_timestamp,
+                    context_state=new_context_state,
                 )
                 B[i].add(new_hyp)

     B = B + finalized_B

+    # Finalize context_state: if a matched context does not reach its final
+    # state, we need to add the score of the corresponding backoff arc.
+    if context_graph is not None:
+        finalized_B = [HypothesisList() for _ in range(len(B))]
+        for i, hyps in enumerate(B):
+            for hyp in list(hyps):
+                context_score, new_context_state = context_graph.finalize(
+                    hyp.context_state
+                )
+                finalized_B[i].add(
+                    Hypothesis(
+                        ys=hyp.ys,
+                        log_prob=hyp.log_prob + context_score,
+                        timestamp=hyp.timestamp,
+                        context_state=new_context_state,
+                    )
+                )
+        B = finalized_B
+
     best_hyps = [b.get_most_probable(length_norm=True) for b in B]

     sorted_ans = [h.ys[context_size:] for h in best_hyps]


@@ -125,6 +125,7 @@ For example:
 import argparse
 import logging
 import math
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple

@@ -146,6 +147,7 @@ from beam_search import (
 )
 from train import add_model_arguments, get_params, get_transducer_model

+from icefall import ContextGraph
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,

@@ -353,6 +355,27 @@ def get_parser():
         Used only when the decoding method is fast_beam_search_nbest,
         fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
     )

+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=2,
+        help="""
+        The bonus score for each token of the context biasing words/phrases.
+        Used only when --decoding_method is modified_beam_search.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="""
+        The path of the context biasing lists, one word/phrase per line.
+        Used only when --decoding_method is modified_beam_search.
+        """,
+    )
+
     add_model_arguments(parser)

     return parser

@@ -365,6 +388,7 @@ def decode_one_batch(
     batch: dict,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:

@@ -494,6 +518,7 @@
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
+            context_graph=context_graph,
             return_timestamps=True,
         )
     else:

@@ -548,7 +573,12 @@
         return {key: (hyps, timestamps)}
     else:
-        return {f"beam_size_{params.beam_size}": (hyps, timestamps)}
+        key = f"beam_size_{params.beam_size}"
+        if params.has_contexts:
+            key += f"-context-score-{params.context_score}"
+        else:
+            key += "-no-context-words"
+        return {key: (hyps, timestamps)}


 def decode_dataset(

@@ -558,6 +588,7 @@
     sp: spm.SentencePieceProcessor,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]:
     """Decode dataset.

@@ -622,6 +653,7 @@
             decoding_graph=decoding_graph,
             word_table=word_table,
             batch=batch,
+            context_graph=context_graph,
         )

         for name, (hyps, timestamps_hyp) in hyps_dict.items():

@@ -728,6 +760,12 @@ def main():
         "fast_beam_search_nbest_oracle",
         "modified_beam_search",
     )
+
+    if os.path.exists(params.context_file):
+        params.has_contexts = True
+    else:
+        params.has_contexts = False
+
     params.res_dir = params.exp_dir / params.decoding_method

     if params.iter > 0:

@@ -750,6 +788,10 @@
             params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        if params.has_contexts:
+            params.suffix += f"-context-score-{params.context_score}"
+        else:
+            params.suffix += "-no-context-words"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

@@ -881,6 +923,18 @@
         decoding_graph = None
         word_table = None

+    if params.decoding_method == "modified_beam_search":
+        if os.path.exists(params.context_file):
+            contexts = []
+            for line in open(params.context_file).readlines():
+                contexts.append(line.strip())
+            context_graph = ContextGraph(params.context_score)
+            context_graph.build(sp.encode(contexts))
+        else:
+            context_graph = None
+    else:
+        context_graph = None
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

@@ -905,6 +959,7 @@
             sp=sp,
             word_table=word_table,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
         )

         save_results(


@@ -106,7 +106,7 @@ class WenetSpeechAsrDataModule:
         group.add_argument(
             "--num-buckets",
             type=int,
-            default=300,
+            default=30,
             help="The number of buckets for the DynamicBucketingSampler"
             "(you might want to increase it for larger datasets).",
         )

@@ -364,7 +364,7 @@
         return valid_dl

     def test_dataloaders(self, cuts: CutSet) -> DataLoader:
-        logging.debug("About to create test dataset")
+        logging.info("About to create test dataset")
         test = K2SpeechRecognitionDataset(
             input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
             if self.args.on_the_fly_feats


@@ -92,7 +92,7 @@ When training with the L subset, the streaming usage:
     --causal-convolution 1 \
     --decode-chunk-size 16 \
     --left-context 64

 (4) modified beam search with RNNLM shallow fusion
 ./pruned_transducer_stateless5/decode.py \
     --epoch 35 \

@@ -112,8 +112,10 @@
 import argparse
+import glob
 import logging
 import math
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple

@@ -133,7 +135,8 @@ from beam_search import (
 )
 from train import add_model_arguments, get_params, get_transducer_model

-from icefall import LmScorer, NgramLm
+from icefall import ContextGraph, LmScorer, NgramLm
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,

@@ -307,6 +310,26 @@ def get_parser():
         help="left context can be seen during decoding (in frames after subsampling)",
     )

+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=2,
+        help="""
+        The bonus score for each token of the context biasing words/phrases.
+        Used only when --decoding_method is modified_beam_search.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="""
+        The path of the context biasing lists, one word/phrase per line.
+        Used only when --decoding_method is modified_beam_search.
+        """,
+    )
+
     parser.add_argument(
         "--use-shallow-fusion",
         type=str2bool,

@@ -362,6 +385,7 @@ def decode_one_batch(
     lexicon: Lexicon,
     batch: dict,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
     ngram_lm: Optional[NgramLm] = None,
     ngram_lm_scale: float = 1.0,
     LM: Optional[LmScorer] = None,

@@ -402,14 +426,13 @@
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)

-    feature_lens += params.left_context
-    feature = torch.nn.functional.pad(
-        feature,
-        pad=(0, 0, 0, params.left_context),
-        value=LOG_EPS,
-    )
-
     if params.simulate_streaming:
+        feature_lens += params.left_context
+        feature = torch.nn.functional.pad(
+            feature,
+            pad=(0, 0, 0, params.left_context),
+            value=LOG_EPS,
+        )
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,

@@ -448,6 +471,7 @@
             encoder_out=encoder_out,
             beam=params.beam_size,
             encoder_out_lens=encoder_out_lens,
+            context_graph=context_graph,
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])

@@ -509,7 +533,12 @@
             ): hyps
         }
     else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        key = f"beam_size_{params.beam_size}"
+        if params.has_contexts:
+            key += f"-context-score-{params.context_score}"
+        else:
+            key += "-no-context-words"
+        return {key: hyps}


 def decode_dataset(

@@ -518,6 +547,7 @@
     model: nn.Module,
     lexicon: Lexicon,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
     ngram_lm: Optional[NgramLm] = None,
     ngram_lm_scale: float = 1.0,
     LM: Optional[LmScorer] = None,

@@ -567,6 +597,7 @@
             lexicon=lexicon,
             decoding_graph=decoding_graph,
             batch=batch,
+            context_graph=context_graph,
             ngram_lm=ngram_lm,
             ngram_lm_scale=ngram_lm_scale,
             LM=LM,

@@ -646,6 +677,12 @@ def main():
         "modified_beam_search_lm_shallow_fusion",
         "modified_beam_search_LODR",
     )
+
+    if os.path.exists(params.context_file):
+        params.has_contexts = True
+    else:
+        params.has_contexts = False
+
     params.res_dir = params.exp_dir / params.decoding_method

     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

@@ -655,6 +692,10 @@
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam_size}"
+        if params.has_contexts:
+            params.suffix += f"-context-score-{params.context_score}"
+        else:
+            params.suffix += "-no-context-words"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

@@ -684,11 +725,15 @@
     logging.info(f"Device: {device}")

-    # import pdb; pdb.set_trace()
     lexicon = Lexicon(params.lang_dir)
     params.blank_id = lexicon.token_table["<blk>"]
     params.vocab_size = max(lexicon.tokens) + 1

+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )
+
     if params.simulate_streaming:
         assert (
             params.causal_convolution

@@ -816,6 +861,19 @@
     else:
         decoding_graph = None

+    if params.decoding_method == "modified_beam_search":
+        if os.path.exists(params.context_file):
+            contexts_text = []
+            for line in open(params.context_file).readlines():
+                contexts_text.append(line.strip())
+            contexts = graph_compiler.texts_to_ids(contexts_text)
+            context_graph = ContextGraph(params.context_score)
+            context_graph.build(contexts)
+        else:
+            context_graph = None
+    else:
+        context_graph = None
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

@@ -833,15 +891,16 @@
     test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)

     test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
-    test_dl = [dev_dl, test_net_dl, test_meeting_dl]
+    test_dls = [dev_dl, test_net_dl, test_meeting_dl]

-    for test_set, test_dl in zip(test_sets, test_dl):
+    for test_set, test_dl in zip(test_sets, test_dls):
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,
             model=model,
             lexicon=lexicon,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
             ngram_lm=ngram_lm,
             ngram_lm_scale=ngram_lm_scale,
             LM=LM,

@@ -23,6 +23,8 @@ from .checkpoint import (
     save_checkpoint_with_global_batch_idx,
 )

+from .context_graph import ContextGraph, ContextState
+
 from .decode import (
     get_lattice,
     nbest_decoding,
icefall/context_graph.py (new file, 412 lines)

@@ -0,0 +1,412 @@
# Copyright 2023 Xiaomi Corp. (authors: Wei Kang)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
from collections import deque
from typing import Dict, List, Optional, Tuple


class ContextState:
    """The state in ContextGraph"""

    def __init__(
        self,
        id: int,
        token: int,
        token_score: float,
        node_score: float,
        local_node_score: float,
        is_end: bool,
    ):
        """Create a ContextState.

        Args:
          id:
            The node id, only for visualization now. A node is in
            [0, graph.num_nodes); the id of the root node is always 0.
          token:
            The token id.
          token_score:
            The bonus for each token during decoding, which will hopefully
            boost the token up to survive beam search.
          node_score:
            The accumulated bonus from the root of the graph to the current
            node; it is used to calculate the score of the fail arc.
          local_node_score:
            The accumulated bonus from the last ``end_node`` (a node with
            is_end == True) to the current node; it is used to calculate
            the score of the fail arc. Note: the local_node_score of an
            ``end_node`` is 0.
          is_end:
            True if the current token is the end of a context.
        """
        self.id = id
        self.token = token
        self.token_score = token_score
        self.node_score = node_score
        self.local_node_score = local_node_score
        self.is_end = is_end
        self.next = {}
        self.fail = None
        self.output = None


class ContextGraph:
    """The ContextGraph is modified from Aho-Corasick: it is mainly a Trie
    with a fail arc for each node. See
    https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for more
    details of the Aho-Corasick algorithm.

    A ContextGraph contains some words/phrases that we expect to boost their
    scores during decoding. If a substring of a decoded sequence matches a
    word/phrase in the ContextGraph, we give the decoded sequence a bonus to
    make it survive beam search.
    """

    def __init__(self, context_score: float):
        """Initialize a ContextGraph with the given ``context_score``.

        A root node will be created (**NOTE:** the token of the root is
        hardcoded to -1).

        Args:
          context_score:
            The bonus score for each token (note: NOT for each word/phrase,
            so a longer word/phrase has a larger total bonus, but it has to
            be fully matched).
        """
        self.context_score = context_score
        self.num_nodes = 0
        self.root = ContextState(
            id=self.num_nodes,
            token=-1,
            token_score=0,
            node_score=0,
            local_node_score=0,
            is_end=False,
        )
        self.root.fail = self.root

    def _fill_fail_output(self):
        """This function fills in the fail arc for each trie node; it can be
        computed in linear time by performing a breadth-first search starting
        from the root. See
        https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for the
        details of the algorithm.
        """
        queue = deque()
        for token, node in self.root.next.items():
            node.fail = self.root
            queue.append(node)
        while queue:
            current_node = queue.popleft()
            for token, node in current_node.next.items():
                fail = current_node.fail
                if token in fail.next:
                    fail = fail.next[token]
                else:
                    fail = fail.fail
                    while token not in fail.next:
                        fail = fail.fail
                        if fail.token == -1:  # root
                            break
                    if token in fail.next:
                        fail = fail.next[token]
                node.fail = fail
                # fill the output arc
                output = node.fail
                while not output.is_end:
                    output = output.fail
                    if output.token == -1:  # root
                        output = None
                        break
                node.output = output
                queue.append(node)
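
    # For example (informal): with the contexts {"HE", "SHE"}, the node
    # reached by "SHE" gets a fail arc to the node for "HE" (the longest
    # proper suffix of "SHE" that is also a prefix in the trie) and an
    # output arc to that same node, since "HE" is itself a complete context
    # ending inside "SHE".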

    def build(self, token_ids: List[List[int]]):
        """Build the ContextGraph from a list of token lists.

        It first builds a trie from the given token lists, then fills in the
        fail arc for each trie node. See https://en.wikipedia.org/wiki/Trie
        for how to build a trie.

        Args:
          token_ids:
            The token lists to build the ContextGraph from; it is a list of
            token lists, and each token list contains the token ids of a
            word/phrase. A token id could be the id of a char (modeling with
            single Chinese chars) or the id of a BPE piece (modeling with
            BPEs).
        """
        for tokens in token_ids:
            node = self.root
            for i, token in enumerate(tokens):
                if token not in node.next:
                    self.num_nodes += 1
                    is_end = i == len(tokens) - 1
                    node.next[token] = ContextState(
                        id=self.num_nodes,
                        token=token,
                        token_score=self.context_score,
                        node_score=node.node_score + self.context_score,
                        local_node_score=0
                        if is_end
                        else (node.local_node_score + self.context_score),
                        is_end=is_end,
                    )
                node = node.next[token]
        self._fill_fail_output()
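
    # For example (informal): build([[H, E], [S, H, E]]) with context_score=1
    # creates the nodes H(1), HE(2), S(1), SH(2) and SHE(3), where the number
    # in parentheses is the accumulated node_score; HE and SHE are end nodes.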

    def forward_one_step(
        self, state: ContextState, token: int
    ) -> Tuple[float, ContextState]:
        """Search the graph with the given state and token.

        Args:
          state:
            The current state (i.e. the trie node to start from).
          token:
            The given token.

        Returns:
          Return a tuple of (score, next state).
        """
        node = None
        score = 0
        # token matched
        if token in state.next:
            node = state.next[token]
            score = node.token_score
            if state.is_end:
                score += state.node_score
        else:
            # token not matched
            # We trace along the fail arcs until the token matches or we
            # reach the root of the graph.
            node = state.fail
            while token not in node.next:
                node = node.fail
                if node.token == -1:  # root
                    break
            if token in node.next:
                node = node.next[token]
            # The score of the fail path
            score = node.node_score - state.local_node_score
        assert node is not None
        matched_score = 0
        output = node.output
        while output is not None:
            matched_score += output.node_score
            output = output.output
        return (score + matched_score, node)
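
    # For example (informal): with the single context "SHE" and the input
    # "SHD", the tokens "S" and "H" each earn +context_score, and "D" then
    # falls back to the root, returning
    # root.node_score - local_node_score = -2 * context_score, which cancels
    # the pending bonus of the abandoned partial match.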

    def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
        """When reaching the end of the decoded sequence, we need to finalize
        the matching; the purpose is to subtract the bonus that was added for
        a state that is not the end of a word/phrase.

        Args:
          state:
            The given state (trie node).

        Returns:
          Return a tuple of (score, next state). If the state is the end of
          a word/phrase, the score is zero; otherwise the score is the score
          of an implicit fail arc to the root. The next state is always the
          root.
        """
        # The score of the fail arc
        score = -state.node_score
        if state.is_end:
            score = 0
        return (score, self.root)
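
    # For example (informal): with the single context "SHE", finalize() on
    # the node for "SH" (node_score == 2 * context_score) returns
    # (-2 * context_score, root), while finalize() on the end node "SHE"
    # returns (0, root).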

    def draw(
        self,
        title: Optional[str] = None,
        filename: Optional[str] = "",
        symbol_table: Optional[Dict[int, str]] = None,
    ) -> "Digraph":  # noqa
        """Visualize a ContextGraph via graphviz.

        Render the ContextGraph as an image via graphviz, return the Digraph
        object, and optionally save it to the file ``filename``.
        ``filename`` must have a suffix that graphviz understands, such as
        ``pdf``, ``svg`` or ``png``.

        Note:
          You need to install graphviz to use this function::

            pip install graphviz

        Args:
          title:
            Title to be displayed in the image, e.g. 'A simple FSA example'.
          filename:
            Filename to (optionally) save to, e.g. 'foo.png', 'foo.svg' or
            'foo.pdf' (must have a suffix that graphviz understands).
          symbol_table:
            A map from token ids to symbols.

        Returns:
          A Digraph from graphviz.
        """

        try:
            import graphviz
        except Exception:
            print("You cannot use `draw` unless the graphviz package is installed.")
            raise

        graph_attr = {
            "rankdir": "LR",
            "size": "8.5,11",
            "center": "1",
            "orientation": "Portrait",
            "ranksep": "0.4",
            "nodesep": "0.25",
        }
        if title is not None:
            graph_attr["label"] = title

        default_node_attr = {
            "shape": "circle",
            "style": "bold",
            "fontsize": "14",
        }

        final_state_attr = {
            "shape": "doublecircle",
            "style": "bold",
            "fontsize": "14",
        }

        final_state = -1
        dot = graphviz.Digraph(name="Context Graph", graph_attr=graph_attr)

        seen = set()
        queue = deque()
        queue.append(self.root)
        # The id of the root is always 0
        dot.node("0", label="0", **default_node_attr)
        dot.edge("0", "0", color="red")
        seen.add(0)

        while len(queue):
            current_node = queue.popleft()
            for token, node in current_node.next.items():
                if node.id not in seen:
                    node_score = f"{node.node_score:.2f}".rstrip("0").rstrip(".")
                    local_node_score = f"{node.local_node_score:.2f}".rstrip(
                        "0"
                    ).rstrip(".")
                    label = f"{node.id}/({node_score},{local_node_score})"
                    if node.is_end:
                        dot.node(str(node.id), label=label, **final_state_attr)
                    else:
                        dot.node(str(node.id), label=label, **default_node_attr)
                    seen.add(node.id)
                weight = f"{node.token_score:.2f}".rstrip("0").rstrip(".")
                label = str(token) if symbol_table is None else symbol_table[token]
                dot.edge(str(current_node.id), str(node.id), label=f"{label}/{weight}")
                dot.edge(
                    str(node.id),
                    str(node.fail.id),
                    color="red",
                )
                if node.output is not None:
                    dot.edge(
                        str(node.id),
                        str(node.output.id),
                        color="green",
                    )
                queue.append(node)

        if filename:
            _, extension = os.path.splitext(filename)
            if extension == "" or extension[0] != ".":
                raise ValueError(
                    "Filename needs to have a suffix like .png, .pdf, .svg: {}".format(
                        filename
                    )
                )

            import tempfile

            with tempfile.TemporaryDirectory() as tmp_dir:
                temp_fn = dot.render(
                    filename="temp",
                    directory=tmp_dir,
                    format=extension[1:],
                    cleanup=True,
                )

                shutil.move(temp_fn, filename)

        return dot


if __name__ == "__main__":
    contexts_str = [
        "S",
        "HE",
        "SHE",
        "SHELL",
        "HIS",
        "HERS",
        "HELLO",
        "THIS",
        "THEM",
    ]
    contexts = []
    for s in contexts_str:
        contexts.append([ord(x) for x in s])

    context_graph = ContextGraph(context_score=1)
    context_graph.build(contexts)

    symbol_table = {}
    for contexts in contexts_str:
        for s in contexts:
            symbol_table[ord(s)] = s

    context_graph.draw(
        title="Graph for: " + " / ".join(contexts_str),
        filename="context_graph.pdf",
        symbol_table=symbol_table,
    )

    queries = {
        "HEHERSHE": 14,  # "HE", "HE", "HERS", "S", "SHE", "HE"
        "HERSHE": 12,  # "HE", "HERS", "S", "SHE", "HE"
        "HISHE": 9,  # "HIS", "S", "SHE", "HE"
        "SHED": 6,  # "S", "SHE", "HE"
        "HELL": 2,  # "HE"
        "HELLO": 7,  # "HE", "HELLO"
        "DHRHISQ": 4,  # "HIS", "S"
        "THEN": 2,  # "HE"
    }
    for query, expected_score in queries.items():
        total_scores = 0
        state = context_graph.root
        for q in query:
            score, state = context_graph.forward_one_step(state, ord(q))
            total_scores += score
        score, state = context_graph.finalize(state)
        assert state.token == -1, state.token
        total_scores += score
        assert total_scores == expected_score, (
            total_scores,
            expected_score,
            query,
        )
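
As a concrete check of the scoring rules, the query "HELL" above works out as follows (an informal trace of the code in this file, reusing contexts_str from the tests; only the completed match "HE" keeps its bonus):

    graph = ContextGraph(context_score=1)
    graph.build([[ord(c) for c in w] for w in contexts_str])

    total, state = 0, graph.root
    for ch in "HELL":
        score, state = graph.forward_one_step(state, ord(ch))
        total += score
    # Per-step scores: H -> +1, E -> +1 (completes "HE"),
    # L -> +3 (+1 for the token, +2 node_score of the end node "HE" being
    # extended past), L -> +1; the running total is 6.
    score, state = graph.finalize(state)
    # finalize() returns -4: "HELL" is not an end node, so the whole
    # pending path score is revoked.
    total += score
    assert total == 2  # exactly the node_score of the completed match "HE"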