From 482efb802005cad679f35828b7d972fc1c43676e Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 10 Apr 2023 12:11:29 +0800 Subject: [PATCH 1/8] support combining two LMs --- icefall/shared/combine_lm.py | 183 +++++++++++++++++++++++++++++++++++ 1 file changed, 183 insertions(+) create mode 100755 icefall/shared/combine_lm.py diff --git a/icefall/shared/combine_lm.py b/icefall/shared/combine_lm.py new file mode 100755 index 000000000..b6b7c8e95 --- /dev/null +++ b/icefall/shared/combine_lm.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +# flake8: noqa + +# Copyright 2023 Xiaomi Corp. (author: Fangjun Kuang) +# +""" +Given two LMs "A" and "B", this script modifies +probabilities in A such that + +P_{A_{new}}(w_n|w0,w1,...,w_{n-1}) = P_{A_{original}}(w_n|w0,w1,...,w_{n-1}) / P_{B}(w_n|w0,w1,...,w_{n-1}) + +When it is formulated in log-space, it becomes + +\log P_{A_{new}}(w_n|w0,w1,...,w_{n-1}) = \log P_{A_{original}}(w_n|w0,w1,...,w_{n-1}) - \log P_{B}(w_n|w0,w1,...,w_{n-1}) + +Optionally, you can pass a scale for the LM "B", such that + +\log P_{A_{new}}(w_n|w0,w1,...,w_{n-1}) = \log P_{A_{original}}(w_n|w0,w1,...,w_{n-1}) - scale * \log P_{B}(w_n|w0,w1,...,w_{n-1}) + +Usage: + + python3 ./combine_lm.py \ + --a 4-gram.arpa \ + --b 2-gram.arpa \ + --b-scale 1.0 \ + --out new-4-gram.arpa + +It will generate a new arpa file `new-4-gram.arpa` +""" +import logging +import re +from typing import List + +try: + import kenlm +except ImportError: + print("Please install kenlm first. You can use") + print() + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + +import argparse +from pathlib import Path + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--a", + type=str, + required=True, + help="Path to the first LM. Its order is usually higher than that of LM b", + ) + + parser.add_argument( + "--b", + type=str, + required=True, + help="Path to the second LM. Its order is usually lower than that of LM b", + ) + + parser.add_argument( + "--b-scale", + type=float, + default=1.0, + help="Scale for the second LM.", + ) + + parser.add_argument( + "--out", + type=str, + required=True, + help="Path to save the generate LM.", + ), + + return parser.parse_args() + + +def check_args(args): + assert Path(args.a).is_file(), f"{args.a} does not exist" + assert Path(args.b).is_file(), f"{args.b} does not exist" + + +def get_score(model: kenlm.Model, history: List[str], word: str): + """Compute \log_{10} p(word|history). + + If history is [w0, w1, w2] and word is w3, the function returns + p(w3|w0,w1,w2) + + Caution: + The returned score is in log10. + + Args: + model: + The kenLM model. + history: + The history words. + word: + The current word. + Returns: + Return \log_{10} p(word|history). + """ + order = model.order + history = history[-(order - 1) :] if order > 1 else history + + in_state = kenlm.State() + out_state = kenlm.State() + model.NullContextWrite(in_state) + + for w in history: + model.BaseScore(in_state, w, out_state) + in_state, out_state = out_state, in_state + + return model.BaseScore(in_state, word, out_state) + + +def _process_grams( + a: "_io.TextIOWrapper", + b: kenlm.Model, + b_scale: float, + order: int, + out: "_io.TextIOWrapper", +): + for line in a: + line = line.strip() + if not line: + print("", file=out) + break + + s = line.strip().split() + assert len(s) > order, len(s) + assert len(s) >= order + 1, len(s) + assert len(s) <= order + 2, len(s) + + log10_p_a = float(s[0]) + history = s[1:order] + word = s[order] + + log10_p_b = get_score(b, history, word) + if log10_p_a < b_scale * log10_p_b: + log10_p_a -= b_scale * log10_p_b + + print(f"{log10_p_a:.7f}", end="\t", file=out) + print("\t".join(s[1:]), file=out) + + +def process(args): + b = kenlm.LanguageModel(args.b) + logging.info(f"Order of {args.b}: {b.order}") + pattern = re.compile(r"\\(\d+)-grams:") + out = open(args.out, "w", encoding="utf-8") + + b_scale = args.b_scale + + with open(args.a, encoding="utf-8") as a: + for line in a: + print(line, end="", file=out) + m = pattern.search(line) + if m: + order = int(m.group(1)) + _process_grams(a, b, b_scale=b_scale, order=order, out=out) + out.close() + + +def main(): + args = get_args() + logging.info(vars(args)) + check_args(args) + + process(args) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() From c3c6d4c3c40e97030f2aa1f3c9e4c1d37c0a1d88 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 10 Apr 2023 12:13:36 +0800 Subject: [PATCH 2/8] Support combining two LMs --- icefall/shared/combine_lm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/icefall/shared/combine_lm.py b/icefall/shared/combine_lm.py index b6b7c8e95..d443db921 100755 --- a/icefall/shared/combine_lm.py +++ b/icefall/shared/combine_lm.py @@ -7,11 +7,11 @@ Given two LMs "A" and "B", this script modifies probabilities in A such that -P_{A_{new}}(w_n|w0,w1,...,w_{n-1}) = P_{A_{original}}(w_n|w0,w1,...,w_{n-1}) / P_{B}(w_n|w0,w1,...,w_{n-1}) +P_{A_{new}}(w_n|w_0,w_1,...,w_{n-1}) = P_{A_{original}}(w_n|w_0,w_1,...,w_{n-1}) / P_{B}(w_n|w_0,w_1,...,w_{n-1}) When it is formulated in log-space, it becomes -\log P_{A_{new}}(w_n|w0,w1,...,w_{n-1}) = \log P_{A_{original}}(w_n|w0,w1,...,w_{n-1}) - \log P_{B}(w_n|w0,w1,...,w_{n-1}) +\log P_{A_{new}}(w_n|w_0,w_1,...,w_{n-1}) = \log P_{A_{original}}(w_n|w_0,w_1,...,w_{n-1}) - \log P_{B}(w_n|w_0,w_1,...,w_{n-1}) Optionally, you can pass a scale for the LM "B", such that From c91e7e2fac5d5f99455f652fe19026e0304be422 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 10 Apr 2023 12:17:02 +0800 Subject: [PATCH 3/8] small fixes --- icefall/shared/combine_lm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/icefall/shared/combine_lm.py b/icefall/shared/combine_lm.py index d443db921..82a2d4005 100755 --- a/icefall/shared/combine_lm.py +++ b/icefall/shared/combine_lm.py @@ -107,7 +107,7 @@ def get_score(model: kenlm.Model, history: List[str], word: str): Return \log_{10} p(word|history). """ order = model.order - history = history[-(order - 1) :] if order > 1 else history + history = history[-(order - 1) :] in_state = kenlm.State() out_state = kenlm.State() From 9e338a8e36a53722846fbcb34d12be46b42e8d03 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 10 Apr 2023 12:24:20 +0800 Subject: [PATCH 4/8] Add a scale for LM A --- icefall/shared/combine_lm.py | 43 +++++++++++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/icefall/shared/combine_lm.py b/icefall/shared/combine_lm.py index 82a2d4005..12038bf91 100755 --- a/icefall/shared/combine_lm.py +++ b/icefall/shared/combine_lm.py @@ -13,15 +13,16 @@ When it is formulated in log-space, it becomes \log P_{A_{new}}(w_n|w_0,w_1,...,w_{n-1}) = \log P_{A_{original}}(w_n|w_0,w_1,...,w_{n-1}) - \log P_{B}(w_n|w_0,w_1,...,w_{n-1}) -Optionally, you can pass a scale for the LM "B", such that +Optionally, you can pass scales for LM "A" and LM "B", such that -\log P_{A_{new}}(w_n|w0,w1,...,w_{n-1}) = \log P_{A_{original}}(w_n|w0,w1,...,w_{n-1}) - scale * \log P_{B}(w_n|w0,w1,...,w_{n-1}) +\log P_{A_{new}}(w_n|w0,w1,...,w_{n-1}) = a_scale * \log P_{A_{original}}(w_n|w0,w1,...,w_{n-1}) - b_scale * \log P_{B}(w_n|w0,w1,...,w_{n-1}) Usage: python3 ./combine_lm.py \ --a 4-gram.arpa \ --b 2-gram.arpa \ + --a-scale 1.0 \ --b-scale 1.0 \ --out new-4-gram.arpa @@ -65,6 +66,13 @@ def get_args(): help="Path to the second LM. Its order is usually lower than that of LM b", ) + parser.add_argument( + "--a-scale", + type=float, + default=1.0, + help="Scale for the first LM.", + ) + parser.add_argument( "--b-scale", type=float, @@ -123,10 +131,26 @@ def get_score(model: kenlm.Model, history: List[str], word: str): def _process_grams( a: "_io.TextIOWrapper", b: kenlm.Model, + a_scale: float, b_scale: float, order: int, out: "_io.TextIOWrapper", ): + """ + Args: + a: + A file handle for the LM "A" + b: + LM B. + a_scale: + The scale for scores from LM A. + b_scale: + The scale for scores from LM B. + order: int + Current order of LM A. + out: + File handle for the output LM. + """ for line in a: line = line.strip() if not line: @@ -143,8 +167,9 @@ def _process_grams( word = s[order] log10_p_b = get_score(b, history, word) - if log10_p_a < b_scale * log10_p_b: - log10_p_a -= b_scale * log10_p_b + if a_scale * log10_p_a < b_scale * log10_p_b: + # ensure that the resulting log10_p_a is negative + log10_p_a = a_scale * log10_p_a - b_scale * log10_p_b print(f"{log10_p_a:.7f}", end="\t", file=out) print("\t".join(s[1:]), file=out) @@ -156,6 +181,7 @@ def process(args): pattern = re.compile(r"\\(\d+)-grams:") out = open(args.out, "w", encoding="utf-8") + a_scale = args.a_scale b_scale = args.b_scale with open(args.a, encoding="utf-8") as a: @@ -164,7 +190,14 @@ def process(args): m = pattern.search(line) if m: order = int(m.group(1)) - _process_grams(a, b, b_scale=b_scale, order=order, out=out) + _process_grams( + a=a, + b=b, + a_scale=a_scale, + b_scale=b_scale, + order=order, + out=out, + ) out.close() From 17465143483a44dbb4d735e981c2105daacc91f4 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 10 Apr 2023 12:29:08 +0800 Subject: [PATCH 5/8] fix typos --- icefall/shared/combine_lm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/icefall/shared/combine_lm.py b/icefall/shared/combine_lm.py index 12038bf91..8e117a4ae 100755 --- a/icefall/shared/combine_lm.py +++ b/icefall/shared/combine_lm.py @@ -38,6 +38,7 @@ except ImportError: print("Please install kenlm first. You can use") print() print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print() print("to install it") import sys @@ -63,7 +64,7 @@ def get_args(): "--b", type=str, required=True, - help="Path to the second LM. Its order is usually lower than that of LM b", + help="Path to the second LM. Its order is usually lower than that of LM a", ) parser.add_argument( @@ -84,7 +85,7 @@ def get_args(): "--out", type=str, required=True, - help="Path to save the generate LM.", + help="Path to save the generated LM.", ), return parser.parse_args() From b64fb60a0d888c04b039d73b393ce7f498b1c9e9 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 10 Apr 2023 12:34:34 +0800 Subject: [PATCH 6/8] small fixes --- icefall/shared/combine_lm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/icefall/shared/combine_lm.py b/icefall/shared/combine_lm.py index 8e117a4ae..9c912006f 100755 --- a/icefall/shared/combine_lm.py +++ b/icefall/shared/combine_lm.py @@ -179,7 +179,7 @@ def _process_grams( def process(args): b = kenlm.LanguageModel(args.b) logging.info(f"Order of {args.b}: {b.order}") - pattern = re.compile(r"\\(\d+)-grams:") + pattern = re.compile(r"\\(\d+)-grams:\n") out = open(args.out, "w", encoding="utf-8") a_scale = args.a_scale @@ -188,7 +188,7 @@ def process(args): with open(args.a, encoding="utf-8") as a: for line in a: print(line, end="", file=out) - m = pattern.search(line) + m = pattern.match(line) if m: order = int(m.group(1)) _process_grams( From 0b1492bbf891de4f89a009477f986947fb4ff6e0 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 10 Apr 2023 14:36:37 +0800 Subject: [PATCH 7/8] Also scale down the backoff prob --- icefall/shared/combine_lm.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/icefall/shared/combine_lm.py b/icefall/shared/combine_lm.py index 9c912006f..aed55baaa 100755 --- a/icefall/shared/combine_lm.py +++ b/icefall/shared/combine_lm.py @@ -167,13 +167,23 @@ def _process_grams( history = s[1:order] word = s[order] + log10_p_a_backoff = 0 if len(s) < order + 2 else float(s[-1]) + log10_p_b = get_score(b, history, word) if a_scale * log10_p_a < b_scale * log10_p_b: # ensure that the resulting log10_p_a is negative log10_p_a = a_scale * log10_p_a - b_scale * log10_p_b + else: + log10_p_a *= a_scale + + log10_p_a_backoff *= a_scale print(f"{log10_p_a:.7f}", end="\t", file=out) - print("\t".join(s[1:]), file=out) + if len(s) < order + 2: + print("\t".join(s[1:]), file=out) + else: + print("\t".join(s[1:-1]), end="\t", file=out) + print(f"{log10_p_a_backoff:.7f}", file=out) def process(args): From b71f0428f52d3a50151a0bdfbcef366c1367f235 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 10 Apr 2023 14:38:55 +0800 Subject: [PATCH 8/8] small fixes --- icefall/shared/combine_lm.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/icefall/shared/combine_lm.py b/icefall/shared/combine_lm.py index aed55baaa..d7b85e1b4 100755 --- a/icefall/shared/combine_lm.py +++ b/icefall/shared/combine_lm.py @@ -170,11 +170,9 @@ def _process_grams( log10_p_a_backoff = 0 if len(s) < order + 2 else float(s[-1]) log10_p_b = get_score(b, history, word) - if a_scale * log10_p_a < b_scale * log10_p_b: - # ensure that the resulting log10_p_a is negative - log10_p_a = a_scale * log10_p_a - b_scale * log10_p_b - else: - log10_p_a *= a_scale + + # ensure that the resulting log10_p_a is not positive + log10_p_a = min(0, a_scale * log10_p_a - b_scale * log10_p_b) log10_p_a_backoff *= a_scale