diff --git a/icefall/shared/combine_lm.py b/icefall/shared/combine_lm.py new file mode 100755 index 000000000..b6b7c8e95 --- /dev/null +++ b/icefall/shared/combine_lm.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +# flake8: noqa + +# Copyright 2023 Xiaomi Corp. (author: Fangjun Kuang) +# +""" +Given two LMs "A" and "B", this script modifies +probabilities in A such that + +P_{A_{new}}(w_n|w0,w1,...,w_{n-1}) = P_{A_{original}}(w_n|w0,w1,...,w_{n-1}) / P_{B}(w_n|w0,w1,...,w_{n-1}) + +When it is formulated in log-space, it becomes + +\log P_{A_{new}}(w_n|w0,w1,...,w_{n-1}) = \log P_{A_{original}}(w_n|w0,w1,...,w_{n-1}) - \log P_{B}(w_n|w0,w1,...,w_{n-1}) + +Optionally, you can pass a scale for the LM "B", such that + +\log P_{A_{new}}(w_n|w0,w1,...,w_{n-1}) = \log P_{A_{original}}(w_n|w0,w1,...,w_{n-1}) - scale * \log P_{B}(w_n|w0,w1,...,w_{n-1}) + +Usage: + + python3 ./combine_lm.py \ + --a 4-gram.arpa \ + --b 2-gram.arpa \ + --b-scale 1.0 \ + --out new-4-gram.arpa + +It will generate a new arpa file `new-4-gram.arpa` +""" +import logging +import re +from typing import List + +try: + import kenlm +except ImportError: + print("Please install kenlm first. You can use") + print() + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + +import argparse +from pathlib import Path + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--a", + type=str, + required=True, + help="Path to the first LM. Its order is usually higher than that of LM b", + ) + + parser.add_argument( + "--b", + type=str, + required=True, + help="Path to the second LM. Its order is usually lower than that of LM b", + ) + + parser.add_argument( + "--b-scale", + type=float, + default=1.0, + help="Scale for the second LM.", + ) + + parser.add_argument( + "--out", + type=str, + required=True, + help="Path to save the generate LM.", + ), + + return parser.parse_args() + + +def check_args(args): + assert Path(args.a).is_file(), f"{args.a} does not exist" + assert Path(args.b).is_file(), f"{args.b} does not exist" + + +def get_score(model: kenlm.Model, history: List[str], word: str): + """Compute \log_{10} p(word|history). + + If history is [w0, w1, w2] and word is w3, the function returns + p(w3|w0,w1,w2) + + Caution: + The returned score is in log10. + + Args: + model: + The kenLM model. + history: + The history words. + word: + The current word. + Returns: + Return \log_{10} p(word|history). + """ + order = model.order + history = history[-(order - 1) :] if order > 1 else history + + in_state = kenlm.State() + out_state = kenlm.State() + model.NullContextWrite(in_state) + + for w in history: + model.BaseScore(in_state, w, out_state) + in_state, out_state = out_state, in_state + + return model.BaseScore(in_state, word, out_state) + + +def _process_grams( + a: "_io.TextIOWrapper", + b: kenlm.Model, + b_scale: float, + order: int, + out: "_io.TextIOWrapper", +): + for line in a: + line = line.strip() + if not line: + print("", file=out) + break + + s = line.strip().split() + assert len(s) > order, len(s) + assert len(s) >= order + 1, len(s) + assert len(s) <= order + 2, len(s) + + log10_p_a = float(s[0]) + history = s[1:order] + word = s[order] + + log10_p_b = get_score(b, history, word) + if log10_p_a < b_scale * log10_p_b: + log10_p_a -= b_scale * log10_p_b + + print(f"{log10_p_a:.7f}", end="\t", file=out) + print("\t".join(s[1:]), file=out) + + +def process(args): + b = kenlm.LanguageModel(args.b) + logging.info(f"Order of {args.b}: {b.order}") + pattern = re.compile(r"\\(\d+)-grams:") + out = open(args.out, "w", encoding="utf-8") + + b_scale = args.b_scale + + with open(args.a, encoding="utf-8") as a: + for line in a: + print(line, end="", file=out) + m = pattern.search(line) + if m: + order = int(m.group(1)) + _process_grams(a, b, b_scale=b_scale, order=order, out=out) + out.close() + + +def main(): + args = get_args() + logging.info(vars(args)) + check_args(args) + + process(args) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main()