Mirror of https://github.com/k2-fsa/icefall.git
Keep disambig tokens and backoff arcs in LG.
This commit is contained in:
parent 954b4efff3 · commit 5af23efa69
@@ -105,31 +105,15 @@ def compile_LG(lang_dir: str) -> k2.Fsa:
        f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}"
    )

    logging.info("Removing disambiguation symbols on LG")
    LG.labels[LG.labels >= first_token_disambig_id] = 0
    # See https://github.com/k2-fsa/k2/issues/874
    # for why we need to set LG.properties to None
    LG.__dict__["_properties"] = None

    logging.info("Removing epsilons")
    LG = k2.remove_epsilon(LG)
    logging.info(
        f"LG shape after k2.remove_epsilon: {LG.shape}, num_arcs: {LG.num_arcs}"
    )

    logging.info("Connecting")
    LG = k2.connect(LG)
    logging.info(
        f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}"
    )

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    logging.info(f"LG properties: {LG.properties_str}")
    # Possible properties are:
    # "Valid|Nonempty|ArcSorted|EpsilonFree|MaybeAccessible|MaybeCoaccessible"
    logging.info("Caution: LG is not deterministic!!!")
    # "Valid|Nonempty|ArcSorted|ArcSortedAndDeterministic|EpsilonFree|MaybeAccessible|MaybeCoaccessible"  # noqa
    logging.info(
        "Caution: LG is deterministic and contains disambig symbols!!!"
    )

    return LG

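For reference, a minimal sketch (not part of the commit) of how the resulting LG.pt can be inspected; it assumes the lang dir is data/lang_bpe_500 and uses only k2 calls that already appear elsewhere in this diff:

import k2
import torch

# Assumption: LG.pt was written by local/compile_lg.py into data/lang_bpe_500.
LG = k2.Fsa.from_dict(torch.load("data/lang_bpe_500/LG.pt", map_location="cpu"))

# After this change LG keeps the disambig tokens and the #0 backoff arcs,
# so its property string is expected to report it as arc sorted and deterministic.
print(f"num_states: {LG.shape[0]}, num_arcs: {LG.num_arcs}")
print(f"properties: {LG.properties_str}")

assert "ArcSorted" in LG.properties_str
assert "EpsilonFree" in LG.properties_str
assert "Deterministic" in LG.properties_str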
@@ -21,58 +21,75 @@ To run this file, do:
    python ./local/test_compile_lg.py
"""

import os

from pathlib import Path
from typing import List

import k2
import sentencepiece as spm
import torch

lang_dir = Path("./data/lang_bpe_500")
corpus = "test_compile_lg_corpus.txt"
arpa = "test_compile_lg_3_gram.arpa"
G_fst_txt = "test_compile_lg_3_gram.fst.txt"


def get_word_ids(word_table: k2.SymbolTable, s: str) -> List[int]:
def generate_corpus():
    s = """HELLO WORLD
HELLOA WORLDER
HELLOA WORLDER HELLO
HELLOA WORLDER"""
    with open(corpus, "w") as f:
        f.write(s)


def generate_arpa():
    cmd = f"""
      ./shared/make_kn_lm.py \
        -ngram-order 3 \
        -text {corpus} \
        -lm {arpa}
    """
    Args:
      word_table:
        Word symbol table.
      s:
        A string consisting of space-separated words.
    Returns:
      Return a list of word IDs.
    os.system(cmd)


def generate_G():
    cmd = f"""
      python3 -m kaldilm \
        --read-symbol-table="{lang_dir}/words.txt" \
        --disambig-symbol='#0' \
        {arpa} > {G_fst_txt}
    """
    ans = []
    for w in s.split():
        ans.append(word_table[w])
    return ans
    os.system(cmd)


def main():
    assert lang_dir.exists(), f"{lang_dir} does not exist!"
    LG = k2.Fsa.from_dict(torch.load(f"{lang_dir}/LG.pt", map_location="cpu"))
    generate_corpus()
    generate_arpa()
    generate_G()
    with open(G_fst_txt) as f:
        G = k2.Fsa.from_openfst(f.read(), acceptor=False)
    del G.aux_labels
    G.labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")
    G.draw("G.pdf", title="G")

    sp = spm.SentencePieceProcessor()
    sp.load(f"{lang_dir}/bpe.model")
    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
    L.labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/tokens.txt")
    L.aux_labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")

    word_table = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")
    s = "HELLO WORLD"
    token_ids = sp.encode(s)
    L = k2.arc_sort(L)
    G = k2.arc_sort(G)

    token_fsa = k2.linear_fsa(token_ids)
    LG = k2.compose(L, G)
    del LG.aux_labels

    fsa = k2.intersect(LG, token_fsa)
    fsa = k2.connect(fsa)
    print(k2.to_dot(fsa))
    print(fsa.properties_str)
    LG = k2.determinize(LG)
    LG = k2.connect(LG)
    LG = k2.arc_sort(LG)
    print(LG.properties_str)
    # You can use https://dreampuf.github.io/GraphvizOnline/
    # to visualize the output.
    #
    # You can see that the resulting fsa is not deterministic
    # Note: LG is non-deterministic
    #
    # See https://shorturl.at/uIL69
    # for visualization of the above fsa.
    LG.draw("LG.pdf", title="LG")
    # You can have a look at G.pdf and LG.pdf to get a feel
    # what they look like


if __name__ == "__main__":

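As a side note (not part of the commit), a small sketch of how one could verify that the #0 backoff arcs survive in the generated G; it assumes generate_G() above has been run and that data/lang_bpe_500/words.txt contains the #0 symbol:

import k2

# Assumption: test_compile_lg_3_gram.fst.txt was produced by generate_G() above.
word_table = k2.SymbolTable.from_file("data/lang_bpe_500/words.txt")
backoff_id = word_table["#0"]

with open("test_compile_lg_3_gram.fst.txt") as f:
    G = k2.Fsa.from_openfst(f.read(), acceptor=False)

# Because kaldilm is invoked with --disambig-symbol='#0', the backoff arcs are
# labeled with #0 instead of epsilon, so they are easy to count.
num_backoff_arcs = int((G.labels == backoff_id).sum())
print(f"G has {num_backoff_arcs} backoff (#0) arcs out of {G.num_arcs} arcs")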
@@ -155,9 +155,14 @@ class HypothesisList(object):
        key = hyp.key
        if key in self:
            old_hyp = self._data[key]  # shallow copy
            torch.logaddexp(
                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
            )

            if True:
                old_hyp.log_prob = torch.logaddexp(
                    old_hyp.log_prob, hyp.log_prob
                )
            else:
                old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob)

            if hyp.ngram_state_and_scores is not None:
                for state, score in hyp.ngram_state_and_scores.items():
                    if (
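To illustrate the two ways of merging duplicate hypotheses above, a small self-contained sketch (plain PyTorch, made-up numbers, not tied to this file):

import torch

# Two hypotheses with the same token sequence (same key) but different scores.
log_prob_a = torch.tensor(-1.5)
log_prob_b = torch.tensor(-2.0)

# Sum of probabilities in log space: log(exp(a) + exp(b)).
combined = torch.logaddexp(log_prob_a, log_prob_b)  # about -1.026

# Viterbi-style alternative kept behind the `else` branch above:
best = max(log_prob_a, log_prob_b)  # -1.5

print(combined.item(), best.item())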
@@ -337,6 +342,7 @@ def modified_beam_search(
    encoder_out: torch.Tensor,
    beam: int = 4,
    LG: Optional[k2.Fsa] = None,
    ngram_lm_scale: float = 0.1,
) -> List[int]:
    """It limits the maximum number of symbols per frame to 1.

@@ -349,11 +355,13 @@ def modified_beam_search(
        Beam size.
      LG:
        Optional. Used for shallow fusion.
      ngram_lm_scale:
        Used only when LG is not None. The total score of a path is
        am_score + ngram_lm_scale * ngram_lm_score
    Returns:
      Return the decoded result.
    """
    enable_shallow_fusion = LG is not None
    ngram_lm_scale = 0.8

    assert encoder_out.ndim == 3

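The scoring rule in the docstring can be written out explicitly; a toy sketch with made-up values, where am_score and ngram_lm_score are illustrative names only:

import torch

am_score = torch.tensor(-3.2)        # log-prob assigned by the transducer
ngram_lm_score = torch.tensor(-1.1)  # log-prob assigned by the n-gram LM (via LG)
ngram_lm_scale = 0.1

total_score = am_score + ngram_lm_scale * ngram_lm_score
print(total_score.item())  # -3.2 + 0.1 * (-1.1) ≈ -3.31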
@@ -422,6 +430,7 @@ def modified_beam_search(
            encoder_out_len.expand(decoder_out.size(0)),
            decoder_out_len.expand(decoder_out.size(0)),
        )
        vocab_size = logits.size(-1)
        # logits is of shape (num_hyps, vocab_size)
        log_probs = logits.log_softmax(dim=-1)

@@ -437,6 +446,9 @@ def modified_beam_search(
        topk_hyp_indexes = topk_hyp_indexes.tolist()
        topk_token_indexes = topk_token_indexes.tolist()

        # import pdb
        #
        # pdb.set_trace()
        for i in range(len(topk_hyp_indexes)):
            hyp = A[topk_hyp_indexes[i]]
            new_ys = hyp.ys[:]
@@ -450,12 +462,15 @@ def modified_beam_search(

            if enable_shallow_fusion and new_token != blank_id:
                ngram_state_and_scores = shallow_fusion(
                    LG, new_token, hyp.ngram_state_and_scores
                    LG,
                    new_token,
                    hyp.ngram_state_and_scores,
                    vocab_size,
                )
                if len(ngram_state_and_scores) == 0:
                    continue
                max_ngram_score = max(ngram_state_and_scores.values())
                new_log_prob += ngram_lm_scale * max_ngram_score
                new_log_prob = new_log_prob + ngram_lm_scale * max_ngram_score

                # TODO: Get the maximum scores in ngram_state_and_scores
                # and add it to new_log_prob
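For clarity, a stand-alone sketch of the fusion step above, using a hypothetical dict of the kind shallow_fusion() returns:

import torch

# Hypothetical output of shallow_fusion(): LG state id -> LM log-prob of reaching it.
ngram_state_and_scores = {3: torch.tensor(-0.4), 8: torch.tensor(-1.2)}
ngram_lm_scale = 0.1
new_log_prob = torch.tensor(-2.0)  # score of the extended hypothesis before fusion

if len(ngram_state_and_scores) == 0:
    print("token not accepted from any tracked LG state; hypothesis is dropped")
else:
    # The best reachable LM state provides the fusion bonus.
    max_ngram_score = max(ngram_state_and_scores.values())
    new_log_prob = new_log_prob + ngram_lm_scale * max_ngram_score
    print(new_log_prob.item())  # -2.0 + 0.1 * (-0.4) = -2.04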
@@ -468,6 +483,9 @@ def modified_beam_search(

            B.add(new_hyp)
        if len(B) == 0:
            import logging

            logging.info("\n*****\nEmpty states!\n***\n")
            for h in A:
                B.add(h)

@@ -139,6 +139,15 @@ def get_parser():
        Used only when --decoding-method is modified_beam_search.""",
    )

    parser.add_argument(
        "--ngram-lm-scale",
        type=float,
        default=0.1,
        help="""Used only when --LG is provided.
        The total score of a path is am_score + ngram_lm_scale * ngram_lm_score.
        """,
    )

    return parser

@@ -279,14 +288,18 @@ def decode_one_batch(
                encoder_out=encoder_out_i,
                beam=params.beam_size,
                LG=LG,
                ngram_lm_scale=params.ngram_lm_scale,
            )
        else:
            raise ValueError(
                f"Unsupported decoding method: {params.decoding_method}"
            )
        hyps.append(sp.decode(hyp).split())
    s = "\n"
    for h in hyps:
        print(" ".join(h))
        s += " ".join(h)
        s += "\n"
    logging.info(s)

    if params.decoding_method == "greedy_search":
        return {"greedy_search": hyps}
@@ -336,6 +349,8 @@ def decode_dataset(

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        if batch_idx > 10:
            break
        texts = batch["supervisions"]["text"]

        hyps_dict = decode_one_batch(
@@ -452,9 +467,10 @@ def main():
        logging.info(f"LG properties: {LG.properties_str}")
        logging.info(f"LG num_states: {LG.shape[0]}, num_arcs: {LG.num_arcs}")
        # If LG is created by local/compile_lg.py, then it should be epsilon
        # free as well as arc sorted
        # free, deterministic, and arc sorted
        assert "ArcSorted" in LG.properties_str
        assert "EpsilonFree" in LG.properties_str
        assert "Deterministic" in LG.properties_str
    else:
        LG = None

@@ -501,6 +517,8 @@ def main():
    test_dl = [test_clean_dl, test_other_dl]

    for test_set, test_dl in zip(test_sets, test_dl):
        if test_set == "test-other":
            break
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,

@@ -18,61 +18,92 @@ from typing import Dict

import k2
import torch
import copy


def shallow_fusion(
    LG: k2.Fsa,
    token: int,
    state_and_scores: Dict[int, torch.Tensor],
    vocab_size: int,
) -> Dict[int, torch.Tensor]:
    """
    Args:
      LG:
        An n-gram. It should be arc sorted and epsilon free.
        An n-gram. It should be arc sorted, deterministic, and epsilon free.
      token:
        The input token ID.
      state_and_scores:
        The keys are the current states we are in and the
        values are the LM log_probs for reaching the corresponding
        states from the start state.
      vocab_size:
        Vocabulary size, including the blank symbol. We assume that
        token IDs >= vocab_size are disambig IDs (including the backoff
        symbol #0).
    Returns:
      Return a new state_and_scores.
    """
    row_splits = LG.arcs.row_splits(1)
    arcs = LG.arcs.values()

    state_and_scores = copy.deepcopy(state_and_scores)

    current_states = list(state_and_scores.keys())

    # Process out-going arcs with label being disambig tokens and #0
    while len(current_states) > 0:
        s = current_states.pop()
        labels_begin = row_splits[s]
        labels_end = row_splits[s + 1]
        labels = LG.labels[labels_begin:labels_end].contiguous()

        for i in reversed(range(labels.numel())):
            lab = labels[i]
            if lab == -1:
                # Note: When sorting arcs, k2 treats arc labels as
                # unsigned types
                continue

            if lab < vocab_size:
                # Since LG is arc sorted, we can exit
                # the for loop as soon as we have a label
                # with ID less than vocab_size
                break

            # This is a disambig token or #0
            idx = labels_begin + i
            next_state = arcs[idx][1].item()
            score = LG.scores[idx] + state_and_scores[s]
            if next_state not in state_and_scores:
                state_and_scores[next_state] = score
                current_states.append(next_state)
            else:
                state_and_scores[next_state] = max(
                    score, state_and_scores[next_state]
                )

    current_states = list(state_and_scores.keys())
    ans = dict()
    for s in current_states:
        labels_begin = row_splits[s]
        labels_end = row_splits[s + 1]
        labels = LG.labels[labels_begin:labels_end].contiguous()

        # As LG is not deterministic, there may be multiple
        # out-going arcs with label equal to "token"
        #
        # Note: LG is arc sorted!
        left = torch.bucketize(token, labels, right=False)
        right = torch.bucketize(token, labels, right=True)
        if labels[-1] == -1:
            labels = labels[:-1]

        if left >= right:
            # There are no out-going arcs from this state
            # that have label equal to "token"
            pos = torch.searchsorted(labels, token)
            if pos >= labels.numel() or labels[pos] != token:
                continue

        # Now we have
        #   labels[i] == token
        # for
        #   left <= i < right
        idx = labels_begin + pos
        next_state = arcs[idx][1].item()
        score = LG.scores[idx] + state_and_scores[s]

        for i in range(left, right):
            i += labels_begin
            next_state = arcs[i][1].item()
            score = LG.scores[i]
            if next_state not in ans:
                ans[next_state] = score
            else:
                ans[next_state] = max(score, ans[next_state])
        if next_state not in ans:
            ans[next_state] = score
        else:
            ans[next_state] = max(score, ans[next_state])

    return ans

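Finally, a self-contained illustration (plain PyTorch, toy numbers) of the label search used above: because LG is arc sorted, the arcs leaving a state have non-decreasing labels, so torch.bucketize yields the half-open range of arcs carrying a given token, and labels >= vocab_size correspond to disambig/backoff symbols such as #0:

import torch

vocab_size = 500

# Hypothetical sorted labels of the arcs leaving one LG state.
# Labels >= vocab_size (here 500 and 501) stand for disambig/backoff symbols.
labels = torch.tensor([2, 5, 5, 7, 500, 501], dtype=torch.int32)
token = 5

left = int(torch.bucketize(token, labels, right=False))  # first index with label >= token -> 1
right = int(torch.bucketize(token, labels, right=True))  # first index with label > token  -> 3

if left >= right:
    print(f"no out-going arc with label {token}")
else:
    print(f"arcs {list(range(left, right))} carry label {token}")

print("disambig/backoff arcs:", (labels >= vocab_size).nonzero(as_tuple=True)[0].tolist())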