mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 18:42:19 +00:00
support combining two LMs
This commit is contained in:
parent
6434c8eadc
commit
482efb8020
183
icefall/shared/combine_lm.py
Executable file
183
icefall/shared/combine_lm.py
Executable file
@ -0,0 +1,183 @@
|
||||
#!/usr/bin/env python3
|
||||
# flake8: noqa
|
||||
|
||||
# Copyright 2023 Xiaomi Corp. (author: Fangjun Kuang)
|
||||
#
|
||||
"""
|
||||
Given two LMs "A" and "B", this script modifies
|
||||
probabilities in A such that
|
||||
|
||||
P_{A_{new}}(w_n|w0,w1,...,w_{n-1}) = P_{A_{original}}(w_n|w0,w1,...,w_{n-1}) / P_{B}(w_n|w0,w1,...,w_{n-1})
|
||||
|
||||
When it is formulated in log-space, it becomes
|
||||
|
||||
\log P_{A_{new}}(w_n|w0,w1,...,w_{n-1}) = \log P_{A_{original}}(w_n|w0,w1,...,w_{n-1}) - \log P_{B}(w_n|w0,w1,...,w_{n-1})
|
||||
|
||||
Optionally, you can pass a scale for the LM "B", such that
|
||||
|
||||
\log P_{A_{new}}(w_n|w0,w1,...,w_{n-1}) = \log P_{A_{original}}(w_n|w0,w1,...,w_{n-1}) - scale * \log P_{B}(w_n|w0,w1,...,w_{n-1})
|
||||
|
||||
Usage:
|
||||
|
||||
python3 ./combine_lm.py \
|
||||
--a 4-gram.arpa \
|
||||
--b 2-gram.arpa \
|
||||
--b-scale 1.0 \
|
||||
--out new-4-gram.arpa
|
||||
|
||||
It will generate a new arpa file `new-4-gram.arpa`
|
||||
"""
|
||||
import logging
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
try:
|
||||
import kenlm
|
||||
except ImportError:
|
||||
print("Please install kenlm first. You can use")
|
||||
print()
|
||||
print(" pip install https://github.com/kpu/kenlm/archive/master.zip")
|
||||
print("to install it")
|
||||
import sys
|
||||
|
||||
sys.exit(-1)
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--a",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the first LM. Its order is usually higher than that of LM b",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--b",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the second LM. Its order is usually lower than that of LM b",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--b-scale",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Scale for the second LM.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--out",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to save the generate LM.",
|
||||
),
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def check_args(args):
|
||||
assert Path(args.a).is_file(), f"{args.a} does not exist"
|
||||
assert Path(args.b).is_file(), f"{args.b} does not exist"
|
||||
|
||||
|
||||
def get_score(model: kenlm.Model, history: List[str], word: str):
|
||||
"""Compute \log_{10} p(word|history).
|
||||
|
||||
If history is [w0, w1, w2] and word is w3, the function returns
|
||||
p(w3|w0,w1,w2)
|
||||
|
||||
Caution:
|
||||
The returned score is in log10.
|
||||
|
||||
Args:
|
||||
model:
|
||||
The kenLM model.
|
||||
history:
|
||||
The history words.
|
||||
word:
|
||||
The current word.
|
||||
Returns:
|
||||
Return \log_{10} p(word|history).
|
||||
"""
|
||||
order = model.order
|
||||
history = history[-(order - 1) :] if order > 1 else history
|
||||
|
||||
in_state = kenlm.State()
|
||||
out_state = kenlm.State()
|
||||
model.NullContextWrite(in_state)
|
||||
|
||||
for w in history:
|
||||
model.BaseScore(in_state, w, out_state)
|
||||
in_state, out_state = out_state, in_state
|
||||
|
||||
return model.BaseScore(in_state, word, out_state)
|
||||
|
||||
|
||||
def _process_grams(
|
||||
a: "_io.TextIOWrapper",
|
||||
b: kenlm.Model,
|
||||
b_scale: float,
|
||||
order: int,
|
||||
out: "_io.TextIOWrapper",
|
||||
):
|
||||
for line in a:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
print("", file=out)
|
||||
break
|
||||
|
||||
s = line.strip().split()
|
||||
assert len(s) > order, len(s)
|
||||
assert len(s) >= order + 1, len(s)
|
||||
assert len(s) <= order + 2, len(s)
|
||||
|
||||
log10_p_a = float(s[0])
|
||||
history = s[1:order]
|
||||
word = s[order]
|
||||
|
||||
log10_p_b = get_score(b, history, word)
|
||||
if log10_p_a < b_scale * log10_p_b:
|
||||
log10_p_a -= b_scale * log10_p_b
|
||||
|
||||
print(f"{log10_p_a:.7f}", end="\t", file=out)
|
||||
print("\t".join(s[1:]), file=out)
|
||||
|
||||
|
||||
def process(args):
|
||||
b = kenlm.LanguageModel(args.b)
|
||||
logging.info(f"Order of {args.b}: {b.order}")
|
||||
pattern = re.compile(r"\\(\d+)-grams:")
|
||||
out = open(args.out, "w", encoding="utf-8")
|
||||
|
||||
b_scale = args.b_scale
|
||||
|
||||
with open(args.a, encoding="utf-8") as a:
|
||||
for line in a:
|
||||
print(line, end="", file=out)
|
||||
m = pattern.search(line)
|
||||
if m:
|
||||
order = int(m.group(1))
|
||||
_process_grams(a, b, b_scale=b_scale, order=order, out=out)
|
||||
out.close()
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
logging.info(vars(args))
|
||||
check_args(args)
|
||||
|
||||
process(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user