Add a scale for LM A

This commit is contained in:
Fangjun Kuang 2023-04-10 12:24:20 +08:00
parent c91e7e2fac
commit 9e338a8e36

View File

@ -13,15 +13,16 @@ When it is formulated in log-space, it becomes
\log P_{A_{new}}(w_n|w_0,w_1,...,w_{n-1}) = \log P_{A_{original}}(w_n|w_0,w_1,...,w_{n-1}) - \log P_{B}(w_n|w_0,w_1,...,w_{n-1})
Optionally, you can pass a scale for the LM "B", such that
Optionally, you can pass scales for LM "A" and LM "B", such that
\log P_{A_{new}}(w_n|w0,w1,...,w_{n-1}) = \log P_{A_{original}}(w_n|w0,w1,...,w_{n-1}) - scale * \log P_{B}(w_n|w0,w1,...,w_{n-1})
\log P_{A_{new}}(w_n|w_0,w_1,...,w_{n-1}) = a_scale * \log P_{A_{original}}(w_n|w_0,w_1,...,w_{n-1}) - b_scale * \log P_{B}(w_n|w_0,w_1,...,w_{n-1})
Usage:
python3 ./combine_lm.py \
--a 4-gram.arpa \
--b 2-gram.arpa \
--a-scale 1.0 \
--b-scale 1.0 \
--out new-4-gram.arpa
@ -65,6 +66,13 @@ def get_args():
help="Path to the second LM. Its order is usually lower than that of LM b",
)
parser.add_argument(
"--a-scale",
type=float,
default=1.0,
help="Scale for the first LM.",
)
parser.add_argument(
"--b-scale",
type=float,
@ -123,10 +131,26 @@ def get_score(model: kenlm.Model, history: List[str], word: str):
def _process_grams(
a: "_io.TextIOWrapper",
b: kenlm.Model,
a_scale: float,
b_scale: float,
order: int,
out: "_io.TextIOWrapper",
):
"""
Args:
a:
A file handle for the LM "A"
b:
LM B.
a_scale:
The scale for scores from LM A.
b_scale:
The scale for scores from LM B.
order:
The n-gram order of the section of LM A currently being processed, as parsed from its "\N-grams:" header.
out:
File handle for the output LM.
"""
for line in a:
line = line.strip()
if not line:
@ -143,8 +167,9 @@ def _process_grams(
word = s[order]
log10_p_b = get_score(b, history, word)
if log10_p_a < b_scale * log10_p_b:
log10_p_a -= b_scale * log10_p_b
if a_scale * log10_p_a < b_scale * log10_p_b:
# ensure that the resulting log10_p_a is negative
log10_p_a = a_scale * log10_p_a - b_scale * log10_p_b
print(f"{log10_p_a:.7f}", end="\t", file=out)
print("\t".join(s[1:]), file=out)
@ -156,6 +181,7 @@ def process(args):
pattern = re.compile(r"\\(\d+)-grams:")
out = open(args.out, "w", encoding="utf-8")
a_scale = args.a_scale
b_scale = args.b_scale
with open(args.a, encoding="utf-8") as a:
@ -164,7 +190,14 @@ def process(args):
m = pattern.search(line)
if m:
order = int(m.group(1))
_process_grams(a, b, b_scale=b_scale, order=order, out=out)
_process_grams(
a=a,
b=b,
a_scale=a_scale,
b_scale=b_scale,
order=order,
out=out,
)
out.close()