mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 18:42:19 +00:00
Add a scale for LM A
This commit is contained in:
parent
c91e7e2fac
commit
9e338a8e36
@ -13,15 +13,16 @@ When it is formulated in log-space, it becomes
|
||||
|
||||
\log P_{A_{new}}(w_n|w_0,w_1,...,w_{n-1}) = \log P_{A_{original}}(w_n|w_0,w_1,...,w_{n-1}) - \log P_{B}(w_n|w_0,w_1,...,w_{n-1})
|
||||
|
||||
Optionally, you can pass a scale for the LM "B", such that
|
||||
Optionally, you can pass scales for LM "A" and LM "B", such that
|
||||
|
||||
\log P_{A_{new}}(w_n|w0,w1,...,w_{n-1}) = \log P_{A_{original}}(w_n|w0,w1,...,w_{n-1}) - scale * \log P_{B}(w_n|w0,w1,...,w_{n-1})
|
||||
\log P_{A_{new}}(w_n|w0,w1,...,w_{n-1}) = a_scale * \log P_{A_{original}}(w_n|w0,w1,...,w_{n-1}) - b_scale * \log P_{B}(w_n|w0,w1,...,w_{n-1})
|
||||
|
||||
Usage:
|
||||
|
||||
python3 ./combine_lm.py \
|
||||
--a 4-gram.arpa \
|
||||
--b 2-gram.arpa \
|
||||
--a-scale 1.0 \
|
||||
--b-scale 1.0 \
|
||||
--out new-4-gram.arpa
|
||||
|
||||
@ -65,6 +66,13 @@ def get_args():
|
||||
help="Path to the second LM. Its order is usually lower than that of LM b",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--a-scale",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Scale for the first LM.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--b-scale",
|
||||
type=float,
|
||||
@ -123,10 +131,26 @@ def get_score(model: kenlm.Model, history: List[str], word: str):
|
||||
def _process_grams(
|
||||
a: "_io.TextIOWrapper",
|
||||
b: kenlm.Model,
|
||||
a_scale: float,
|
||||
b_scale: float,
|
||||
order: int,
|
||||
out: "_io.TextIOWrapper",
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
a:
|
||||
A file handle for the LM "A"
|
||||
b:
|
||||
LM B.
|
||||
a_scale:
|
||||
The scale for scores from LM A.
|
||||
b_scale:
|
||||
The scale for scores from LM B.
|
||||
order: int
|
||||
Current order of LM A.
|
||||
out:
|
||||
File handle for the output LM.
|
||||
"""
|
||||
for line in a:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
@ -143,8 +167,9 @@ def _process_grams(
|
||||
word = s[order]
|
||||
|
||||
log10_p_b = get_score(b, history, word)
|
||||
if log10_p_a < b_scale * log10_p_b:
|
||||
log10_p_a -= b_scale * log10_p_b
|
||||
if a_scale * log10_p_a < b_scale * log10_p_b:
|
||||
# ensure that the resulting log10_p_a is negative
|
||||
log10_p_a = a_scale * log10_p_a - b_scale * log10_p_b
|
||||
|
||||
print(f"{log10_p_a:.7f}", end="\t", file=out)
|
||||
print("\t".join(s[1:]), file=out)
|
||||
@ -156,6 +181,7 @@ def process(args):
|
||||
pattern = re.compile(r"\\(\d+)-grams:")
|
||||
out = open(args.out, "w", encoding="utf-8")
|
||||
|
||||
a_scale = args.a_scale
|
||||
b_scale = args.b_scale
|
||||
|
||||
with open(args.a, encoding="utf-8") as a:
|
||||
@ -164,7 +190,14 @@ def process(args):
|
||||
m = pattern.search(line)
|
||||
if m:
|
||||
order = int(m.group(1))
|
||||
_process_grams(a, b, b_scale=b_scale, order=order, out=out)
|
||||
_process_grams(
|
||||
a=a,
|
||||
b=b,
|
||||
a_scale=a_scale,
|
||||
b_scale=b_scale,
|
||||
order=order,
|
||||
out=out,
|
||||
)
|
||||
out.close()
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user