mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 20:12:24 +00:00
Add a scale for LM A
This commit is contained in:
parent
c91e7e2fac
commit
9e338a8e36
@ -13,15 +13,16 @@ When it is formulated in log-space, it becomes
|
|||||||
|
|
||||||
\log P_{A_{new}}(w_n|w_0,w_1,...,w_{n-1}) = \log P_{A_{original}}(w_n|w_0,w_1,...,w_{n-1}) - \log P_{B}(w_n|w_0,w_1,...,w_{n-1})
|
\log P_{A_{new}}(w_n|w_0,w_1,...,w_{n-1}) = \log P_{A_{original}}(w_n|w_0,w_1,...,w_{n-1}) - \log P_{B}(w_n|w_0,w_1,...,w_{n-1})
|
||||||
|
|
||||||
Optionally, you can pass a scale for the LM "B", such that
|
Optionally, you can pass scales for LM "A" and LM "B", such that
|
||||||
|
|
||||||
\log P_{A_{new}}(w_n|w_0,w_1,...,w_{n-1}) = \log P_{A_{original}}(w_n|w_0,w_1,...,w_{n-1}) - scale * \log P_{B}(w_n|w_0,w_1,...,w_{n-1})
|
\log P_{A_{new}}(w_n|w_0,w_1,...,w_{n-1}) = a_scale * \log P_{A_{original}}(w_n|w_0,w_1,...,w_{n-1}) - b_scale * \log P_{B}(w_n|w_0,w_1,...,w_{n-1})
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
|
|
||||||
python3 ./combine_lm.py \
|
python3 ./combine_lm.py \
|
||||||
--a 4-gram.arpa \
|
--a 4-gram.arpa \
|
||||||
--b 2-gram.arpa \
|
--b 2-gram.arpa \
|
||||||
|
--a-scale 1.0 \
|
||||||
--b-scale 1.0 \
|
--b-scale 1.0 \
|
||||||
--out new-4-gram.arpa
|
--out new-4-gram.arpa
|
||||||
|
|
||||||
@ -65,6 +66,13 @@ def get_args():
|
|||||||
help="Path to the second LM. Its order is usually lower than that of LM b",
|
help="Path to the second LM. Its order is usually lower than that of LM b",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--a-scale",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help="Scale for the first LM.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--b-scale",
|
"--b-scale",
|
||||||
type=float,
|
type=float,
|
||||||
@ -123,10 +131,26 @@ def get_score(model: kenlm.Model, history: List[str], word: str):
|
|||||||
def _process_grams(
|
def _process_grams(
|
||||||
a: "_io.TextIOWrapper",
|
a: "_io.TextIOWrapper",
|
||||||
b: kenlm.Model,
|
b: kenlm.Model,
|
||||||
|
a_scale: float,
|
||||||
b_scale: float,
|
b_scale: float,
|
||||||
order: int,
|
order: int,
|
||||||
out: "_io.TextIOWrapper",
|
out: "_io.TextIOWrapper",
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
a:
|
||||||
|
A file handle for the LM "A"
|
||||||
|
b:
|
||||||
|
LM B.
|
||||||
|
a_scale:
|
||||||
|
The scale for scores from LM A.
|
||||||
|
b_scale:
|
||||||
|
The scale for scores from LM B.
|
||||||
|
order: int
|
||||||
|
Current order of LM A.
|
||||||
|
out:
|
||||||
|
File handle for the output LM.
|
||||||
|
"""
|
||||||
for line in a:
|
for line in a:
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
if not line:
|
if not line:
|
||||||
@ -143,8 +167,9 @@ def _process_grams(
|
|||||||
word = s[order]
|
word = s[order]
|
||||||
|
|
||||||
log10_p_b = get_score(b, history, word)
|
log10_p_b = get_score(b, history, word)
|
||||||
if log10_p_a < b_scale * log10_p_b:
|
if a_scale * log10_p_a < b_scale * log10_p_b:
|
||||||
log10_p_a -= b_scale * log10_p_b
|
# ensure that the resulting log10_p_a is negative
|
||||||
|
log10_p_a = a_scale * log10_p_a - b_scale * log10_p_b
|
||||||
|
|
||||||
print(f"{log10_p_a:.7f}", end="\t", file=out)
|
print(f"{log10_p_a:.7f}", end="\t", file=out)
|
||||||
print("\t".join(s[1:]), file=out)
|
print("\t".join(s[1:]), file=out)
|
||||||
@ -156,6 +181,7 @@ def process(args):
|
|||||||
pattern = re.compile(r"\\(\d+)-grams:")
|
pattern = re.compile(r"\\(\d+)-grams:")
|
||||||
out = open(args.out, "w", encoding="utf-8")
|
out = open(args.out, "w", encoding="utf-8")
|
||||||
|
|
||||||
|
a_scale = args.a_scale
|
||||||
b_scale = args.b_scale
|
b_scale = args.b_scale
|
||||||
|
|
||||||
with open(args.a, encoding="utf-8") as a:
|
with open(args.a, encoding="utf-8") as a:
|
||||||
@ -164,7 +190,14 @@ def process(args):
|
|||||||
m = pattern.search(line)
|
m = pattern.search(line)
|
||||||
if m:
|
if m:
|
||||||
order = int(m.group(1))
|
order = int(m.group(1))
|
||||||
_process_grams(a, b, b_scale=b_scale, order=order, out=out)
|
_process_grams(
|
||||||
|
a=a,
|
||||||
|
b=b,
|
||||||
|
a_scale=a_scale,
|
||||||
|
b_scale=b_scale,
|
||||||
|
order=order,
|
||||||
|
out=out,
|
||||||
|
)
|
||||||
out.close()
|
out.close()
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user