Keep disambig tokens and backoff arcs in LG.

Fangjun Kuang 2022-02-10 20:28:59 +08:00
parent 954b4efff3
commit 5af23efa69
5 changed files with 152 additions and 84 deletions

View File

@@ -105,31 +105,15 @@ def compile_LG(lang_dir: str) -> k2.Fsa:
         f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}"
     )
-    logging.info("Removing disambiguation symbols on LG")
-    LG.labels[LG.labels >= first_token_disambig_id] = 0
-    # See https://github.com/k2-fsa/k2/issues/874
-    # for why we need to set LG.properties to None
-    LG.__dict__["_properties"] = None
-    logging.info("Removing epsilons")
-    LG = k2.remove_epsilon(LG)
-    logging.info(
-        f"LG shape after k2.remove_epsilon: {LG.shape}, num_arcs: {LG.num_arcs}"
-    )
-    logging.info("Connecting")
-    LG = k2.connect(LG)
-    logging.info(
-        f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}"
-    )
     logging.info("Arc sorting LG")
     LG = k2.arc_sort(LG)
     logging.info(f"LG properties: {LG.properties_str}")
     # Possible properties are:
-    # "Valid|Nonempty|ArcSorted|EpsilonFree|MaybeAccessible|MaybeCoaccessible"
-    logging.info("Caution: LG is not deterministic!!!")
+    # "Valid|Nonempty|ArcSorted|ArcSortedAndDeterministic|EpsilonFree|MaybeAccessible|MaybeCoaccessible"  # noqa
+    logging.info(
+        "Caution: LG is deterministic and contains disambig symbols!!!"
+    )
     return LG
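
With disambig symbols and backoff arcs kept, compiling LG reduces to compose + determinize + arc-sort. A minimal sketch of that flow (the paths are illustrative; the real script also wires up symbol tables and logging):

    import k2
    import torch

    L = k2.Fsa.from_dict(torch.load("data/lang_bpe_500/L_disambig.pt"))
    with open("data/lm/G_3_gram.fst.txt") as f:
        G = k2.Fsa.from_openfst(f.read(), acceptor=False)

    LG = k2.compose(k2.arc_sort(L), k2.arc_sort(G))
    LG = k2.connect(LG)
    LG = k2.determinize(LG)
    LG = k2.connect(LG)
    LG = k2.arc_sort(LG)
    # Disambig tokens and #0 backoff arcs remain in LG.labels, so
    # LG.properties_str should now contain "ArcSortedAndDeterministic".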

View File

@@ -21,58 +21,75 @@ To run this file, do:
     python ./local/test_compile_lg.py
 """
+import os
 from pathlib import Path
-from typing import List
 
 import k2
-import sentencepiece as spm
 import torch
 
 lang_dir = Path("./data/lang_bpe_500")
+corpus = "test_compile_lg_corpus.txt"
+arpa = "test_compile_lg_3_gram.arpa"
+G_fst_txt = "test_compile_lg_3_gram.fst.txt"
 
-def get_word_ids(word_table: k2.SymbolTable, s: str) -> List[int]:
-    """
-    Args:
-      word_table:
-        Word symbol table.
-      s:
-        A string of space-separated words.
-    Returns:
-      Return a list of word IDs.
-    """
-    ans = []
-    for w in s.split():
-        ans.append(word_table[w])
-    return ans
+def generate_corpus():
+    s = """HELLO WORLD
+HELLOA WORLDER
+HELLOA WORLDER HELLO
+HELLOA WORLDER"""
+    with open(corpus, "w") as f:
+        f.write(s)
+
+def generate_arpa():
+    cmd = f"""
+      ./shared/make_kn_lm.py \
+        -ngram-order 3 \
+        -text {corpus} \
+        -lm {arpa}
+    """
+    os.system(cmd)
+
+def generate_G():
+    cmd = f"""
+      python3 -m kaldilm \
+        --read-symbol-table="{lang_dir}/words.txt" \
+        --disambig-symbol='#0' \
+        {arpa} > {G_fst_txt}
+    """
+    os.system(cmd)
 
 def main():
-    assert lang_dir.exists(), f"{lang_dir} does not exist!"
-    LG = k2.Fsa.from_dict(torch.load(f"{lang_dir}/LG.pt", map_location="cpu"))
+    generate_corpus()
+    generate_arpa()
+    generate_G()
+
+    with open(G_fst_txt) as f:
+        G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+    del G.aux_labels
+    G.labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")
+    G.draw("G.pdf", title="G")
 
-    sp = spm.SentencePieceProcessor()
-    sp.load(f"{lang_dir}/bpe.model")
-
-    word_table = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")
-
-    s = "HELLO WORLD"
-    token_ids = sp.encode(s)
+    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+    L.labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/tokens.txt")
+    L.aux_labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")
 
-    token_fsa = k2.linear_fsa(token_ids)
+    L = k2.arc_sort(L)
+    G = k2.arc_sort(G)
 
-    fsa = k2.intersect(LG, token_fsa)
-    fsa = k2.connect(fsa)
-    print(k2.to_dot(fsa))
-    print(fsa.properties_str)
-    # You can use https://dreampuf.github.io/GraphvizOnline/
-    # to visualize the output.
-    #
-    # You can see that the resulting fsa is not deterministic
-    # Note: LG is non-deterministic
-    #
-    # See https://shorturl.at/uIL69
-    # for visualization of the above fsa.
+    LG = k2.compose(L, G)
+    del LG.aux_labels
+
+    LG = k2.determinize(LG)
+    LG = k2.connect(LG)
+    LG = k2.arc_sort(LG)
+    print(LG.properties_str)
+    LG.draw("LG.pdf", title="LG")
+    # You can have a look at G.pdf and LG.pdf to get a feel for
+    # what they look like
 
 if __name__ == "__main__":
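
As a quick sanity check that disambig and backoff arcs actually survive in the compiled graph, one can count labels outside the regular token range. A hedged sketch (the path and vocab_size=500 are assumptions matching this recipe's lang_bpe_500 setup):

    import k2
    import torch

    LG = k2.Fsa.from_dict(
        torch.load("data/lang_bpe_500/LG.pt", map_location="cpu")
    )
    vocab_size = 500  # BPE vocab size incl. blank; disambig IDs are >= this
    labels = LG.labels.to(torch.int64)
    num_special = ((labels >= vocab_size) & (labels != -1)).sum().item()
    print(f"arcs labeled with disambig/#0 symbols: {num_special}")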

View File

@@ -155,9 +155,14 @@ class HypothesisList(object):
         key = hyp.key
         if key in self:
             old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(
-                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
-            )
+            if True:
+                old_hyp.log_prob = torch.logaddexp(
+                    old_hyp.log_prob, hyp.log_prob
+                )
+            else:
+                old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob)
             if hyp.ngram_state_and_scores is not None:
                 for state, score in hyp.ngram_state_and_scores.items():
                     if (
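
The `if True:` switch above toggles between two ways of merging duplicate hypotheses: torch.logaddexp sums their probabilities, while max keeps only the better path (Viterbi-style). A toy illustration, independent of icefall:

    import torch

    log_p1 = torch.tensor(-1.2)  # log-prob of the existing hypothesis
    log_p2 = torch.tensor(-2.3)  # log-prob of a duplicate hypothesis

    merged_sum = torch.logaddexp(log_p1, log_p2)  # log(exp(-1.2) + exp(-2.3))
    merged_max = torch.max(log_p1, log_p2)

    print(merged_sum.item())  # ~ -0.91; never smaller than the max
    print(merged_max.item())  # -1.2
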
@@ -337,6 +342,7 @@ def modified_beam_search(
     encoder_out: torch.Tensor,
     beam: int = 4,
     LG: Optional[k2.Fsa] = None,
+    ngram_lm_scale: float = 0.1,
 ) -> List[int]:
     """It limits the maximum number of symbols per frame to 1.
@@ -349,11 +355,13 @@
         Beam size.
       LG:
         Optional. Used for shallow fusion.
+      ngram_lm_scale:
+        Used only when LG is not None. The total score of a path is
+        am_score + ngram_lm_scale * ngram_lm_score.
     Returns:
       Return the decoded result.
     """
     enable_shallow_fusion = LG is not None
-    ngram_lm_scale = 0.8

     assert encoder_out.ndim == 3
@@ -422,6 +430,7 @@
             encoder_out_len.expand(decoder_out.size(0)),
             decoder_out_len.expand(decoder_out.size(0)),
         )
+        vocab_size = logits.size(-1)

         # logits is of shape (num_hyps, vocab_size)
         log_probs = logits.log_softmax(dim=-1)
@@ -437,6 +446,9 @@
         topk_hyp_indexes = topk_hyp_indexes.tolist()
         topk_token_indexes = topk_token_indexes.tolist()
        for i in range(len(topk_hyp_indexes)):
            hyp = A[topk_hyp_indexes[i]]
            new_ys = hyp.ys[:]
@@ -450,12 +462,15 @@
                if enable_shallow_fusion and new_token != blank_id:
                    ngram_state_and_scores = shallow_fusion(
-                        LG, new_token, hyp.ngram_state_and_scores
+                        LG,
+                        new_token,
+                        hyp.ngram_state_and_scores,
+                        vocab_size,
                    )
+                    if len(ngram_state_and_scores) == 0:
+                        continue
                    max_ngram_score = max(ngram_state_and_scores.values())
-                    new_log_prob += ngram_lm_scale * max_ngram_score
-                    # TODO: Get the maximum scores in ngram_state_and_scores
-                    # and add it to new_log_prob
+                    new_log_prob = new_log_prob + ngram_lm_scale * max_ngram_score
@@ -468,6 +483,9 @@
                    B.add(new_hyp)

        if len(B) == 0:
+            import logging
+
+            logging.info("\n*****\nEmpty states!\n***\n")
            for h in A:
                B.add(h)
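
The LM contribution above is the best n-gram score over all LM states reachable with the new token, scaled by ngram_lm_scale. A toy numeric sketch of that combination (all values made up):

    import torch

    am_log_prob = -0.70  # model's log-prob for the new token
    ngram_state_and_scores = {3: torch.tensor(-1.5), 8: torch.tensor(-0.9)}
    ngram_lm_scale = 0.1

    max_ngram_score = max(ngram_state_and_scores.values())
    new_log_prob = am_log_prob + ngram_lm_scale * max_ngram_score
    print(new_log_prob.item())  # -0.70 + 0.1 * (-0.9) = -0.79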

View File

@@ -139,6 +139,15 @@ def get_parser():
        Used only when --decoding-method is modified_beam_search.""",
    )

+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.1,
+        help="""Used only when --LG is provided.
+        The total score of a path is am_score + ngram_lm_scale * ngram_lm_score.
+        """,
+    )
+
    return parser
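
The new flag plugs into the recipe's existing parser; a standalone miniature for illustration only:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--ngram-lm-scale", type=float, default=0.1)
    args = parser.parse_args(["--ngram-lm-scale", "0.25"])
    print(args.ngram_lm_scale)  # 0.25
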
@@ -279,14 +288,18 @@ def decode_one_batch(
                encoder_out=encoder_out_i,
                beam=params.beam_size,
                LG=LG,
+                ngram_lm_scale=params.ngram_lm_scale,
            )
        else:
            raise ValueError(
                f"Unsupported decoding method: {params.decoding_method}"
            )
        hyps.append(sp.decode(hyp).split())

+    s = "\n"
    for h in hyps:
-        print(" ".join(h))
+        s += " ".join(h)
+        s += "\n"
+    logging.info(s)

    if params.decoding_method == "greedy_search":
        return {"greedy_search": hyps}
@@ -336,6 +349,8 @@ def decode_dataset(
    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
+        if batch_idx > 10:
+            break
        texts = batch["supervisions"]["text"]

        hyps_dict = decode_one_batch(
@@ -452,9 +467,10 @@ def main():
        logging.info(f"LG properties: {LG.properties_str}")
        logging.info(f"LG num_states: {LG.shape[0]}, num_arcs: {LG.num_arcs}")
        # If LG is created by local/compile_lg.py, then it should be epsilon
-        # free as well as arc sorted
+        # free, deterministic, and arc sorted
        assert "ArcSorted" in LG.properties_str
        assert "EpsilonFree" in LG.properties_str
+        assert "Deterministic" in LG.properties_str
    else:
        LG = None
@@ -501,6 +517,8 @@ def main():
    test_dl = [test_clean_dl, test_other_dl]

    for test_set, test_dl in zip(test_sets, test_dl):
+        if test_set == "test-other":
+            break
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,

View File

@@ -18,61 +18,92 @@ from typing import Dict

import k2
import torch
+import copy

def shallow_fusion(
    LG: k2.Fsa,
    token: int,
    state_and_scores: Dict[int, torch.Tensor],
+    vocab_size: int,
) -> Dict[int, torch.Tensor]:
    """
    Args:
      LG:
-        An n-gram. It should be arc sorted and epsilon free.
+        An n-gram. It should be arc sorted, deterministic, and epsilon free.
      token:
        The input token ID.
      state_and_scores:
        The keys contain the current states we are in and the
        values are the LM log_probs for reaching the corresponding
        states from the start state.
+      vocab_size:
+        Vocabulary size, including the blank symbol. We assume that
+        token IDs >= vocab_size are disambig IDs (including the backoff
+        symbol #0).
    Returns:
      Return a new state_and_scores.
    """
    row_splits = LG.arcs.row_splits(1)
    arcs = LG.arcs.values()

+    state_and_scores = copy.deepcopy(state_and_scores)
+
+    current_states = list(state_and_scores.keys())
+    # Process out-going arcs whose labels are disambig tokens or #0
+    while len(current_states) > 0:
+        s = current_states.pop()
+        labels_begin = row_splits[s]
+        labels_end = row_splits[s + 1]
+        labels = LG.labels[labels_begin:labels_end].contiguous()
+
+        for i in reversed(range(labels.numel())):
+            lab = labels[i]
+            if lab == -1:
+                # Note: When sorting arcs, k2 treats arc labels as
+                # unsigned types
+                continue
+            if lab < vocab_size:
+                # Since LG is arc sorted, we can exit the for loop
+                # as soon as we see a label with ID less than vocab_size
+                break
+
+            # This is a disambig token or #0
+            idx = labels_begin + i
+            next_state = arcs[idx][1].item()
+            score = LG.scores[idx] + state_and_scores[s]
+            if next_state not in state_and_scores:
+                state_and_scores[next_state] = score
+                current_states.append(next_state)
+            else:
+                state_and_scores[next_state] = max(
+                    score, state_and_scores[next_state]
+                )

    current_states = list(state_and_scores.keys())
    ans = dict()
    for s in current_states:
        labels_begin = row_splits[s]
        labels_end = row_splits[s + 1]
        labels = LG.labels[labels_begin:labels_end].contiguous()

-        # As LG is not deterministic, there may be multiple
-        # out-going arcs with label equal to "token"
-        #
-        # Note: LG is arc sorted!
-        left = torch.bucketize(token, labels, right=False)
-        right = torch.bucketize(token, labels, right=True)
+        if labels[-1] == -1:
+            labels = labels[:-1]

-        if left >= right:
-            # There are no out-going arcs from this state
-            # with label equal to "token"
+        pos = torch.searchsorted(labels, token)
+        if pos >= labels.numel() or labels[pos] != token:
            continue

-        # Now we have
-        #   labels[i] == token
-        # for
-        #   left <= i < right
-        for i in range(left, right):
-            i += labels_begin
-            next_state = arcs[i][1].item()
-            score = LG.scores[i]
-            if next_state not in ans:
-                ans[next_state] = score
-            else:
-                ans[next_state] = max(score, ans[next_state])
+        idx = labels_begin + pos
+        next_state = arcs[idx][1].item()
+        score = LG.scores[idx] + state_and_scores[s]
+        if next_state not in ans:
+            ans[next_state] = score
+        else:
+            ans[next_state] = max(score, ans[next_state])

    return ans
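
A hedged usage sketch of the updated interface (the state IDs, token ID, and vocab_size are made up; LG is an arc-sorted, deterministic, epsilon-free n-gram as described above):

    import torch

    # Start from LG's start state (state 0) with a zero LM score.
    state_and_scores = {0: torch.zeros(1)}

    # Advance the LM by one decoded non-blank token, e.g. token ID 17.
    # Disambig/#0 arcs (labels >= vocab_size) are followed first, so
    # backoff transitions are taken into account automatically.
    state_and_scores = shallow_fusion(LG, 17, state_and_scores, vocab_size=500)

    if len(state_and_scores) == 0:
        print("no LG state has an out-going arc labeled 17")
    for state, score in state_and_scores.items():
        print(state, score.item())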