Keep disambig tokens and backoff arcs in LG.

Fangjun Kuang 2022-02-10 20:28:59 +08:00
parent 954b4efff3
commit 5af23efa69
5 changed files with 152 additions and 84 deletions


@@ -105,31 +105,15 @@ def compile_LG(lang_dir: str) -> k2.Fsa:
         f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}"
     )
-    logging.info("Removing disambiguation symbols on LG")
-    LG.labels[LG.labels >= first_token_disambig_id] = 0
-    # See https://github.com/k2-fsa/k2/issues/874
-    # for why we need to set LG.properties to None
-    LG.__dict__["_properties"] = None
-    logging.info("Removing epsilons")
-    LG = k2.remove_epsilon(LG)
-    logging.info(
-        f"LG shape after k2.remove_epsilon: {LG.shape}, num_arcs: {LG.num_arcs}"
-    )
-    logging.info("Connecting")
-    LG = k2.connect(LG)
-    logging.info(
-        f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}"
-    )
     logging.info("Arc sorting LG")
     LG = k2.arc_sort(LG)
     logging.info(f"LG properties: {LG.properties_str}")
     # Possible properties is:
-    # "Valid|Nonempty|ArcSorted|EpsilonFree|MaybeAccessible|MaybeCoaccessible"
-    logging.info("Caution: LG is not deterministic!!!")
+    # "Valid|Nonempty|ArcSorted|ArcSortedAndDeterministic|EpsilonFree|MaybeAccessible|MaybeCoaccessible"  # noqa
+    logging.info(
+        "Caution: LG is deterministic and contains disambig symbols!!!"
+    )
     return LG
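For orientation, the LG returned by this function is now built by composing L_disambig with G and determinizing, without ever zeroing out the disambig/backoff labels. The sketch below is illustrative only (paths and file names are assumptions, not quoted from compile_LG); it mirrors the new test script further down in this commit:

    import k2
    import torch

    lang_dir = "data/lang_bpe_500"  # illustrative path

    # L_disambig.pt keeps the disambig symbols; G.fst.txt is assumed to come
    # from kaldilm with --disambig-symbol='#0', so backoff arcs carry #0.
    L = k2.Fsa.from_dict(
        torch.load(f"{lang_dir}/L_disambig.pt", map_location="cpu")
    )
    with open(f"{lang_dir}/G.fst.txt") as f:
        G = k2.Fsa.from_openfst(f.read(), acceptor=False)
    del G.aux_labels  # treat G as an acceptor over word IDs

    LG = k2.compose(k2.arc_sort(L), k2.arc_sort(G))
    del LG.aux_labels
    LG = k2.determinize(LG)  # possible because disambig symbols were kept
    LG = k2.connect(LG)
    LG = k2.arc_sort(LG)  # decode.py asserts ArcSorted/EpsilonFree/Deterministic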


@@ -21,58 +21,75 @@ To run this file, do:
     python ./local/test_compile_lg.py
 """
 
+import os
 from pathlib import Path
-from typing import List
 
 import k2
-import sentencepiece as spm
 import torch
 
 lang_dir = Path("./data/lang_bpe_500")
 
+corpus = "test_compile_lg_corpus.txt"
+arpa = "test_compile_lg_3_gram.arpa"
+G_fst_txt = "test_compile_lg_3_gram.fst.txt"
 
-def get_word_ids(word_table: k2.SymbolTable, s: str) -> List[int]:
-    """
-    Args:
-      word_table:
-        Word symbol table.
-      s:
-        A string consisting of space(s) separated words.
-    Returns:
-      Return a list of word IDs.
-    """
-    ans = []
-    for w in s.split():
-        ans.append(word_table[w])
-    return ans
+
+def generate_corpus():
+    s = """HELLO WORLD
+HELLOA WORLDER
+HELLOA WORLDER HELLO
+HELLOA WORLDER"""
+    with open(corpus, "w") as f:
+        f.write(s)
+
+
+def generate_arpa():
+    cmd = f"""
+      ./shared/make_kn_lm.py \
+        -ngram-order 3 \
+        -text {corpus} \
+        -lm {arpa}
+    """
+    os.system(cmd)
+
+
+def generate_G():
+    cmd = f"""
+    python3 -m kaldilm \
+        --read-symbol-table="{lang_dir}/words.txt" \
+        --disambig-symbol='#0' \
+        {arpa} > {G_fst_txt}
+    """
+    os.system(cmd)
 
 
 def main():
-    assert lang_dir.exists(), f"{lang_dir} does not exist!"
-    LG = k2.Fsa.from_dict(torch.load(f"{lang_dir}/LG.pt", map_location="cpu"))
-
-    sp = spm.SentencePieceProcessor()
-    sp.load(f"{lang_dir}/bpe.model")
-
-    word_table = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")
-    s = "HELLO WORLD"
-    token_ids = sp.encode(s)
-    token_fsa = k2.linear_fsa(token_ids)
-
-    fsa = k2.intersect(LG, token_fsa)
-    fsa = k2.connect(fsa)
-    print(k2.to_dot(fsa))
-    print(fsa.properties_str)
+    generate_corpus()
+    generate_arpa()
+    generate_G()
+    with open(G_fst_txt) as f:
+        G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+    del G.aux_labels
+    G.labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")
+    G.draw("G.pdf", title="G")
+
+    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+    L.labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/tokens.txt")
+    L.aux_labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")
+
+    L = k2.arc_sort(L)
+    G = k2.arc_sort(G)
+
+    LG = k2.compose(L, G)
+    del LG.aux_labels
+
+    LG = k2.determinize(LG)
+    LG = k2.connect(LG)
+    LG = k2.arc_sort(LG)
     print(LG.properties_str)
-    # You can use https://dreampuf.github.io/GraphvizOnline/
-    # to visualize the output.
-    #
-    # You can see that the resulting fsa is not deterministic
-    # Note: LG is non-deterministic
-    #
-    # See https://shorturl.at/uIL69
-    # for visualization of the above fsa.
+    LG.draw("LG.pdf", title="LG")
+    # You can have a look at G.pdf and LG.pdf to get a feel
+    # what they look like
 
 
 if __name__ == "__main__":
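A quick sanity check that a compiled LG really kept its disambig/backoff arcs is to count labels at or above the vocabulary size (the decoding code below treats those IDs as disambig symbols, including #0). This is a hedged sketch; the path and vocab_size value are illustrative:

    import k2
    import torch

    LG = k2.Fsa.from_dict(
        torch.load("data/lang_bpe_500/LG.pt", map_location="cpu")
    )
    vocab_size = 500  # illustrative; use the BPE model's vocab size (incl. blank)
    num_special = int((LG.labels >= vocab_size).sum())
    print(f"arcs labeled with #0 or disambig symbols: {num_special}")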


@@ -155,9 +155,14 @@ class HypothesisList(object):
         key = hyp.key
         if key in self:
             old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(
-                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
-            )
+            if True:
+                old_hyp.log_prob = torch.logaddexp(
+                    old_hyp.log_prob, hyp.log_prob
+                )
+            else:
+                old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob)
             if hyp.ngram_state_and_scores is not None:
                 for state, score in hyp.ngram_state_and_scores.items():
                     if (
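The `if True:` switch above toggles between summing the probabilities of two hypotheses that share the same token sequence (torch.logaddexp) and keeping only the better of the two (max). A tiny standalone illustration with made-up numbers:

    import torch

    log_p1 = torch.tensor(-1.2)  # log-prob of one path
    log_p2 = torch.tensor(-2.3)  # log-prob of another path with the same tokens

    summed = torch.logaddexp(log_p1, log_p2)  # log(e**-1.2 + e**-2.3) ~= -0.91
    best = max(log_p1, log_p2)                # keeps -1.2 only
    print(summed.item(), best.item())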
@@ -337,6 +342,7 @@ def modified_beam_search(
     encoder_out: torch.Tensor,
     beam: int = 4,
     LG: Optional[k2.Fsa] = None,
+    ngram_lm_scale: float = 0.1,
 ) -> List[int]:
     """It limits the maximum number of symbols per frame to 1.
@@ -349,11 +355,13 @@ def modified_beam_search(
         Beam size.
       LG:
         Optional. Used for shallow fusion.
+      ngram_lm_scale:
+        Used only when LG is not None. The total score of a path is
+        am_score + ngram_lm_scale * ngram_lm_score.
     Returns:
       Return the decoded result.
     """
     enable_shallow_fusion = LG is not None
-    ngram_lm_scale = 0.8
 
     assert encoder_out.ndim == 3
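A worked instance of the scoring formula in the new docstring, with made-up numbers:

    # Illustrative values only, not taken from any experiment.
    am_score = -3.2
    ngram_lm_scale = 0.1
    ngram_lm_score = -7.0
    total = am_score + ngram_lm_scale * ngram_lm_score  # -3.2 + 0.1 * (-7.0) = -3.9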
@@ -422,6 +430,7 @@ def modified_beam_search(
             encoder_out_len.expand(decoder_out.size(0)),
             decoder_out_len.expand(decoder_out.size(0)),
         )
+        vocab_size = logits.size(-1)
         # logits is of shape (num_hyps, vocab_size)
         log_probs = logits.log_softmax(dim=-1)
@@ -437,6 +446,9 @@ def modified_beam_search(
         topk_hyp_indexes = topk_hyp_indexes.tolist()
         topk_token_indexes = topk_token_indexes.tolist()
 
+        # import pdb
+        #
+        # pdb.set_trace()
         for i in range(len(topk_hyp_indexes)):
             hyp = A[topk_hyp_indexes[i]]
             new_ys = hyp.ys[:]
@@ -450,12 +462,15 @@ def modified_beam_search(
             if enable_shallow_fusion and new_token != blank_id:
                 ngram_state_and_scores = shallow_fusion(
-                    LG, new_token, hyp.ngram_state_and_scores
+                    LG,
+                    new_token,
+                    hyp.ngram_state_and_scores,
+                    vocab_size,
                 )
                 if len(ngram_state_and_scores) == 0:
                     continue
                 max_ngram_score = max(ngram_state_and_scores.values())
-                new_log_prob += ngram_lm_scale * max_ngram_score
+                new_log_prob = new_log_prob + ngram_lm_scale * max_ngram_score
                 # TODO: Get the maximum scores in ngram_state_and_scores
                 # and add it to new_log_prob
@@ -468,6 +483,9 @@ def modified_beam_search(
             B.add(new_hyp)
 
         if len(B) == 0:
+            import logging
+
+            logging.info("\n*****\nEmpty states!\n***\n")
             for h in A:
                 B.add(h)


@@ -139,6 +139,15 @@ def get_parser():
         Used only when --decoding-method is modified_beam_search.""",
     )
 
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.1,
+        help="""Used only when --LG is provided.
+        The total score of a path is am_score + ngram_lm_scale * ngram_lm_score.
+        """,
+    )
+
     return parser
@@ -279,14 +288,18 @@ def decode_one_batch(
                 encoder_out=encoder_out_i,
                 beam=params.beam_size,
                 LG=LG,
+                ngram_lm_scale=params.ngram_lm_scale,
             )
         else:
             raise ValueError(
                 f"Unsupported decoding method: {params.decoding_method}"
             )
         hyps.append(sp.decode(hyp).split())
 
+    s = "\n"
     for h in hyps:
-        print(" ".join(h))
+        s += " ".join(h)
+        s += "\n"
+    logging.info(s)
 
     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
@@ -336,6 +349,8 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
+        if batch_idx > 10:
+            break
         texts = batch["supervisions"]["text"]
 
         hyps_dict = decode_one_batch(
@@ -452,9 +467,10 @@ def main():
         logging.info(f"LG properties: {LG.properties_str}")
         logging.info(f"LG num_states: {LG.shape[0]}, num_arcs: {LG.num_arcs}")
         # If LG is created by local/compile_lg.py, then it should be epsilon
-        # free as well as arc sorted
+        # free, deterministic, and arc sorted
         assert "ArcSorted" in LG.properties_str
         assert "EpsilonFree" in LG.properties_str
+        assert "Deterministic" in LG.properties_str
     else:
         LG = None
@@ -501,6 +517,8 @@ def main():
     test_dl = [test_clean_dl, test_other_dl]
 
     for test_set, test_dl in zip(test_sets, test_dl):
+        if test_set == "test-other":
+            break
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,


@@ -18,61 +18,92 @@ from typing import Dict
 import k2
 import torch
+import copy
 
 
 def shallow_fusion(
     LG: k2.Fsa,
     token: int,
     state_and_scores: Dict[int, torch.Tensor],
+    vocab_size: int,
 ) -> Dict[int, torch.Tensor]:
     """
     Args:
       LG:
-        An n-gram. It should be arc sorted and epsilon free.
+        An n-gram. It should be arc sorted, deterministic, and epsilon free.
       token:
         The input token ID.
       state_and_scores:
         The keys contain the current state we are in and the
         values are the LM log_prob for reaching the corresponding
        states from the start state.
+      vocab_size:
+        Vocabulary size, including the blank symbol. We assume that
+        token IDs >= vocab_size are disambig IDs (including the backoff
+        symbol #0).
     Returns:
       Return a new state_and_scores.
     """
     row_splits = LG.arcs.row_splits(1)
     arcs = LG.arcs.values()
 
+    state_and_scores = copy.deepcopy(state_and_scores)
+
     current_states = list(state_and_scores.keys())
 
+    # Process out-going arcs with label being disambig tokens and #0
+    while len(current_states) > 0:
+        s = current_states.pop()
+        labels_begin = row_splits[s]
+        labels_end = row_splits[s + 1]
+        labels = LG.labels[labels_begin:labels_end].contiguous()
+
+        for i in reversed(range(labels.numel())):
+            lab = labels[i]
+            if lab == -1:
+                # Note: When sorting arcs, k2 treats arc labels as
+                # unsigned types
+                continue
+            if lab < vocab_size:
+                # Since LG is arc sorted, we can exit
+                # the for loop as soon as we have a label
+                # with ID less than vocab_size
+                break
+
+            # This is a disambig token or #0
+            idx = labels_begin + i
+            next_state = arcs[idx][1].item()
+            score = LG.scores[idx] + state_and_scores[s]
+            if next_state not in state_and_scores:
+                state_and_scores[next_state] = score
+                current_states.append(next_state)
+            else:
+                state_and_scores[next_state] = max(
+                    score, state_and_scores[next_state]
+                )
+
+    current_states = list(state_and_scores.keys())
+
     ans = dict()
     for s in current_states:
         labels_begin = row_splits[s]
         labels_end = row_splits[s + 1]
         labels = LG.labels[labels_begin:labels_end].contiguous()
 
-        # As LG is not deterministic, there may be multiple
-        # out-going arcs that with label equal to "token"
-        #
-        # Note: LG is arc sorted!
-        left = torch.bucketize(token, labels, right=False)
-        right = torch.bucketize(token, labels, right=True)
-        if left >= right:
-            # There are no out-going arcs from this state
-            # that have label equal to "token"
-            continue
-        # Now we have
-        #  labels[i] == token
-        # for
-        #  left <= i < right
-        for i in range(left, right):
-            i += labels_begin
-            next_state = arcs[i][1].item()
-            score = LG.scores[i]
-            if next_state not in ans:
-                ans[next_state] = score
-            else:
-                ans[next_state] = max(score, ans[next_state])
+        if labels[-1] == -1:
+            labels = labels[:-1]
+
+        pos = torch.searchsorted(labels, token)
+        if pos >= labels.numel() or labels[pos] != token:
+            continue
+
+        idx = labels_begin + pos
+        next_state = arcs[idx][1].item()
+        score = LG.scores[idx] + state_and_scores[s]
+
+        if next_state not in ans:
+            ans[next_state] = score
+        else:
+            ans[next_state] = max(score, ans[next_state])
 
     return ans
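To give a feel for how this function is meant to be driven (the beam-search code above keeps one such state dict per hypothesis), here is a toy, self-contained sketch. The 3-state FSA and all numbers are made up and merely stand in for a real LG from compile_lg.py; it assumes shallow_fusion as defined above is importable.

    import k2
    import torch

    # Toy acceptor: state 0 has arcs labeled 5 and 7; -1 marks final arcs.
    s = """
    0 1 5 0.0
    0 2 7 -0.5
    1 3 -1 0.0
    2 3 -1 0.0
    3
    """
    LG = k2.arc_sort(k2.Fsa.from_str(s.strip()))
    vocab_size = 10  # tokens >= 10 would be treated as disambig/#0 labels

    # Start at LG's start state (state 0 in k2) with LM log-prob 0, then
    # advance the n-gram state once per emitted non-blank token.
    state_and_scores = {0: torch.zeros(1)}
    for tok in [5]:  # pretend the decoder emitted token 5
        state_and_scores = shallow_fusion(LG, tok, state_and_scores, vocab_size)
        if len(state_and_scores) == 0:
            break  # no matching arc in LG: no LM bonus for this hypothesis
        lm_score = max(state_and_scores.values())  # used as ngram_lm_score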