From 9e338a8e36a53722846fbcb34d12be46b42e8d03 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 10 Apr 2023 12:24:20 +0800 Subject: [PATCH] Add a scale for LM A --- icefall/shared/combine_lm.py | 43 +++++++++++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/icefall/shared/combine_lm.py b/icefall/shared/combine_lm.py index 82a2d4005..12038bf91 100755 --- a/icefall/shared/combine_lm.py +++ b/icefall/shared/combine_lm.py @@ -13,15 +13,16 @@ When it is formulated in log-space, it becomes \log P_{A_{new}}(w_n|w_0,w_1,...,w_{n-1}) = \log P_{A_{original}}(w_n|w_0,w_1,...,w_{n-1}) - \log P_{B}(w_n|w_0,w_1,...,w_{n-1}) -Optionally, you can pass a scale for the LM "B", such that +Optionally, you can pass scales for LM "A" and LM "B", such that -\log P_{A_{new}}(w_n|w0,w1,...,w_{n-1}) = \log P_{A_{original}}(w_n|w0,w1,...,w_{n-1}) - scale * \log P_{B}(w_n|w0,w1,...,w_{n-1}) +\log P_{A_{new}}(w_n|w0,w1,...,w_{n-1}) = a_scale * \log P_{A_{original}}(w_n|w0,w1,...,w_{n-1}) - b_scale * \log P_{B}(w_n|w0,w1,...,w_{n-1}) Usage: python3 ./combine_lm.py \ --a 4-gram.arpa \ --b 2-gram.arpa \ + --a-scale 1.0 \ --b-scale 1.0 \ --out new-4-gram.arpa @@ -65,6 +66,13 @@ def get_args(): help="Path to the second LM. Its order is usually lower than that of LM b", ) + parser.add_argument( + "--a-scale", + type=float, + default=1.0, + help="Scale for the first LM.", + ) + parser.add_argument( "--b-scale", type=float, @@ -123,10 +131,26 @@ def get_score(model: kenlm.Model, history: List[str], word: str): def _process_grams( a: "_io.TextIOWrapper", b: kenlm.Model, + a_scale: float, b_scale: float, order: int, out: "_io.TextIOWrapper", ): + """ + Args: + a: + A file handle for the LM "A" + b: + LM B. + a_scale: + The scale for scores from LM A. + b_scale: + The scale for scores from LM B. + order: int + Current order of LM A. + out: + File handle for the output LM. + """ for line in a: line = line.strip() if not line: @@ -143,8 +167,9 @@ def _process_grams( word = s[order] log10_p_b = get_score(b, history, word) - if log10_p_a < b_scale * log10_p_b: - log10_p_a -= b_scale * log10_p_b + if a_scale * log10_p_a < b_scale * log10_p_b: + # ensure that the resulting log10_p_a is negative + log10_p_a = a_scale * log10_p_a - b_scale * log10_p_b print(f"{log10_p_a:.7f}", end="\t", file=out) print("\t".join(s[1:]), file=out) @@ -156,6 +181,7 @@ def process(args): pattern = re.compile(r"\\(\d+)-grams:") out = open(args.out, "w", encoding="utf-8") + a_scale = args.a_scale b_scale = args.b_scale with open(args.a, encoding="utf-8") as a: @@ -164,7 +190,14 @@ def process(args): m = pattern.search(line) if m: order = int(m.group(1)) - _process_grams(a, b, b_scale=b_scale, order=order, out=out) + _process_grams( + a=a, + b=b, + a_scale=a_scale, + b_scale=b_scale, + order=order, + out=out, + ) out.close()