Keep disambig tokens and backoff arcs in LG.

commit 5af23efa69
parent 954b4efff3
@@ -105,31 +105,15 @@ def compile_LG(lang_dir: str) -> k2.Fsa:
         f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}"
     )

-    logging.info("Removing disambiguation symbols on LG")
-    LG.labels[LG.labels >= first_token_disambig_id] = 0
-    # See https://github.com/k2-fsa/k2/issues/874
-    # for why we need to set LG.properties to None
-    LG.__dict__["_properties"] = None
-
-    logging.info("Removing epsilons")
-    LG = k2.remove_epsilon(LG)
-    logging.info(
-        f"LG shape after k2.remove_epsilon: {LG.shape}, num_arcs: {LG.num_arcs}"
-    )
-
-    logging.info("Connecting")
-    LG = k2.connect(LG)
-    logging.info(
-        f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}"
-    )
-
     logging.info("Arc sorting LG")
     LG = k2.arc_sort(LG)

     logging.info(f"LG properties: {LG.properties_str}")
     # Possible properties is:
-    # "Valid|Nonempty|ArcSorted|EpsilonFree|MaybeAccessible|MaybeCoaccessible"
-    logging.info("Caution: LG is not deterministic!!!")
+    # "Valid|Nonempty|ArcSorted|ArcSortedAndDeterministic|EpsilonFree|MaybeAccessible|MaybeCoaccessible"  # noqa
+    logging.info(
+        "Caution: LG is deterministic and contains disambig symbols!!!"
+    )
     return LG
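Note: with the removal steps gone, compile_LG leaves the disambiguation symbols and the backoff (#0) arcs in place, which is what keeps LG deterministic. A minimal sanity-check sketch, assuming an LG.pt produced by this script exists under data/lang_bpe_500 (the lang dir used by the test below); the calls mirror the ones in decode.py later in this commit:

import k2
import torch

# Load the graph saved by compile_lg.py (the path is an assumption here).
LG = k2.Fsa.from_dict(torch.load("data/lang_bpe_500/LG.pt", map_location="cpu"))

# After this commit, LG keeps disambig tokens and backoff arcs
# yet should still be deterministic, epsilon free, and arc sorted.
assert "ArcSorted" in LG.properties_str
assert "EpsilonFree" in LG.properties_str
assert "Deterministic" in LG.properties_str
print(f"num_states: {LG.shape[0]}, num_arcs: {LG.num_arcs}")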
@@ -21,58 +21,75 @@ To run this file, do:
     python ./local/test_compile_lg.py
 """

+import os
+
 from pathlib import Path
-from typing import List

 import k2
-import sentencepiece as spm
 import torch

 lang_dir = Path("./data/lang_bpe_500")
+corpus = "test_compile_lg_corpus.txt"
+arpa = "test_compile_lg_3_gram.arpa"
+G_fst_txt = "test_compile_lg_3_gram.fst.txt"


-def get_word_ids(word_table: k2.SymbolTable, s: str) -> List[int]:
-    """
-    Args:
-      word_table:
-        Word symbol table.
-      s:
-        A string consisting of space(s) separated words.
-    Returns:
-      Return a list of word IDs.
-    """
-    ans = []
-    for w in s.split():
-        ans.append(word_table[w])
-    return ans
+def generate_corpus():
+    s = """HELLO WORLD
+HELLOA WORLDER
+HELLOA WORLDER HELLO
+HELLOA WORLDER"""
+    with open(corpus, "w") as f:
+        f.write(s)
+
+
+def generate_arpa():
+    cmd = f"""
+      ./shared/make_kn_lm.py \
+        -ngram-order 3 \
+        -text {corpus} \
+        -lm {arpa}
+    """
+    os.system(cmd)
+
+
+def generate_G():
+    cmd = f"""
+      python3 -m kaldilm \
+        --read-symbol-table="{lang_dir}/words.txt" \
+        --disambig-symbol='#0' \
+        {arpa} > {G_fst_txt}
+    """
+    os.system(cmd)


 def main():
-    assert lang_dir.exists(), f"{lang_dir} does not exist!"
-    LG = k2.Fsa.from_dict(torch.load(f"{lang_dir}/LG.pt", map_location="cpu"))
+    generate_corpus()
+    generate_arpa()
+    generate_G()
+
+    with open(G_fst_txt) as f:
+        G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+    del G.aux_labels
+    G.labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")
+    G.draw("G.pdf", title="G")

-    sp = spm.SentencePieceProcessor()
-    sp.load(f"{lang_dir}/bpe.model")
+    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+    L.labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/tokens.txt")
+    L.aux_labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")

-    word_table = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")
-    s = "HELLO WORLD"
-    token_ids = sp.encode(s)
-
-    token_fsa = k2.linear_fsa(token_ids)
-
-    fsa = k2.intersect(LG, token_fsa)
-    fsa = k2.connect(fsa)
-    print(k2.to_dot(fsa))
-    print(fsa.properties_str)
+    L = k2.arc_sort(L)
+    G = k2.arc_sort(G)
+
+    LG = k2.compose(L, G)
+    del LG.aux_labels
+
+    LG = k2.determinize(LG)
+    LG = k2.connect(LG)
+    LG = k2.arc_sort(LG)
     print(LG.properties_str)
-    # You can use https://dreampuf.github.io/GraphvizOnline/
-    # to visualize the output.
-    #
-    # You can see that the resulting fsa is not deterministic
-    # Note: LG is non-deterministic
-    #
-    # See https://shorturl.at/uIL69
-    # for visualization of the above fsa.
+    LG.draw("LG.pdf", title="LG")
+    # You can have a look at G.pdf and LG.pdf to get a feel for
+    # what they look like


 if __name__ == "__main__":
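The rewritten test builds a toy corpus, trains a 3-gram ARPA LM with make_kn_lm.py, converts it to an FST with kaldilm (keeping the #0 backoff arcs), and checks that L composed with G can be determinized. Condensed, the core of the pipeline looks like this; a sketch assuming the same files and lang dir as the script above, with the expected property name taken from the comment in compile_lg.py:

import k2
import torch

with open("test_compile_lg_3_gram.fst.txt") as f:
    G = k2.Fsa.from_openfst(f.read(), acceptor=False)
del G.aux_labels  # drop word-level aux labels before composition

L = k2.Fsa.from_dict(torch.load("data/lang_bpe_500/L_disambig.pt"))
LG = k2.compose(k2.arc_sort(L), k2.arc_sort(G))
del LG.aux_labels

# Determinization succeeds because L_disambig keeps the disambig symbols.
LG = k2.arc_sort(k2.connect(k2.determinize(LG)))
assert "ArcSortedAndDeterministic" in LG.properties_str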
@@ -155,9 +155,14 @@ class HypothesisList(object):
         key = hyp.key
         if key in self:
             old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(
-                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
-            )
+            if True:
+                old_hyp.log_prob = torch.logaddexp(
+                    old_hyp.log_prob, hyp.log_prob
+                )
+            else:
+                old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob)

             if hyp.ngram_state_and_scores is not None:
                 for state, score in hyp.ngram_state_and_scores.items():
                     if (
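The `if True:` switch above selects between two ways of merging duplicate hypotheses: torch.logaddexp sums the two probabilities in log space, while max keeps only the better path. A small self-contained illustration (the values are made up):

import torch

a = torch.tensor(-1.5)  # log_prob of the hypothesis already in the list
b = torch.tensor(-2.0)  # log_prob of the duplicate being added

merged_sum = torch.logaddexp(a, b)  # log(exp(a) + exp(b)) ≈ -1.0262
merged_max = torch.maximum(a, b)    # Viterbi-style: keep the best, -1.5
print(merged_sum.item(), merged_max.item())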
@@ -337,6 +342,7 @@ def modified_beam_search(
     encoder_out: torch.Tensor,
     beam: int = 4,
     LG: Optional[k2.Fsa] = None,
+    ngram_lm_scale: float = 0.1,
 ) -> List[int]:
     """It limits the maximum number of symbols per frame to 1.
@@ -349,11 +355,13 @@ def modified_beam_search(
         Beam size.
       LG:
         Optional. Used for shallow fusion.
+      ngram_lm_scale:
+        Used only when LG is not None. The total score of a path is
+        am_score + ngram_lm_scale * ngram_lm_score
     Returns:
       Return the decoded result.
     """
     enable_shallow_fusion = LG is not None
-    ngram_lm_scale = 0.8

     assert encoder_out.ndim == 3
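A toy worked example of the interpolation formula from the docstring, using the default scale introduced by this commit (all scores are made up):

am_score = -3.2         # transducer log-prob of the path
ngram_lm_score = -1.7   # n-gram LM log-prob accumulated from LG
ngram_lm_scale = 0.1    # default value of the new argument

total = am_score + ngram_lm_scale * ngram_lm_score
print(total)  # ≈ -3.37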
@@ -422,6 +430,7 @@ def modified_beam_search(
             encoder_out_len.expand(decoder_out.size(0)),
             decoder_out_len.expand(decoder_out.size(0)),
         )
+        vocab_size = logits.size(-1)
         # logits is of shape (num_hyps, vocab_size)
         log_probs = logits.log_softmax(dim=-1)

@@ -437,6 +446,9 @@ def modified_beam_search(
         topk_hyp_indexes = topk_hyp_indexes.tolist()
         topk_token_indexes = topk_token_indexes.tolist()

+        # import pdb
+        #
+        # pdb.set_trace()
         for i in range(len(topk_hyp_indexes)):
             hyp = A[topk_hyp_indexes[i]]
             new_ys = hyp.ys[:]
@@ -450,12 +462,15 @@ def modified_beam_search(

             if enable_shallow_fusion and new_token != blank_id:
                 ngram_state_and_scores = shallow_fusion(
-                    LG, new_token, hyp.ngram_state_and_scores
+                    LG,
+                    new_token,
+                    hyp.ngram_state_and_scores,
+                    vocab_size,
                 )
                 if len(ngram_state_and_scores) == 0:
                     continue
                 max_ngram_score = max(ngram_state_and_scores.values())
-                new_log_prob += ngram_lm_scale * max_ngram_score
+                new_log_prob = new_log_prob + ngram_lm_scale * max_ngram_score

                 # TODO: Get the maximum scores in ngram_state_and_scores
                 # and add it to new_log_prob
@@ -468,6 +483,9 @@ def modified_beam_search(

             B.add(new_hyp)
         if len(B) == 0:
+            import logging
+
+            logging.info("\n*****\nEmpty states!\n***\n")
             for h in A:
                 B.add(h)

@@ -139,6 +139,15 @@ def get_parser():
        Used only when --decoding-method is modified_beam_search.""",
    )

+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.1,
+        help="""Used only when --LG is provided.
+        The total score of a path is am_score + ngram_lm_scale * ngram_lm_score.
+        """,
+    )
+
    return parser

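With the new flag in place, a decoding run can override the scale from the command line. A hypothetical invocation sketch in the style of the test helpers above (the script path is illustrative; only the flag names come from this diff):

import os

# Illustrative only: substitute the recipe's actual decode.py path.
cmd = """
python ./decode.py \
  --decoding-method modified_beam_search \
  --ngram-lm-scale 0.1
"""
os.system(cmd)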
@@ -279,14 +288,18 @@ def decode_one_batch(
                 encoder_out=encoder_out_i,
                 beam=params.beam_size,
                 LG=LG,
+                ngram_lm_scale=params.ngram_lm_scale,
             )
         else:
             raise ValueError(
                 f"Unsupported decoding method: {params.decoding_method}"
             )
         hyps.append(sp.decode(hyp).split())
+    s = "\n"
     for h in hyps:
-        print(" ".join(h))
+        s += " ".join(h)
+        s += "\n"
+    logging.info(s)

     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
@@ -336,6 +349,8 @@ def decode_dataset(

     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
+        if batch_idx > 10:
+            break
         texts = batch["supervisions"]["text"]

         hyps_dict = decode_one_batch(
@@ -452,9 +467,10 @@ def main():
         logging.info(f"LG properties: {LG.properties_str}")
         logging.info(f"LG num_states: {LG.shape[0]}, num_arcs: {LG.num_arcs}")
         # If LG is created by local/compile_lg.py, then it should be epsilon
-        # free as well as arc sorted
+        # free, deterministic, and arc sorted
         assert "ArcSorted" in LG.properties_str
         assert "EpsilonFree" in LG.properties_str
+        assert "Deterministic" in LG.properties_str
     else:
         LG = None

@@ -501,6 +517,8 @@ def main():
     test_dl = [test_clean_dl, test_other_dl]

     for test_set, test_dl in zip(test_sets, test_dl):
+        if test_set == "test-other":
+            break
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,
@@ -18,61 +18,92 @@ from typing import Dict

 import k2
 import torch
+import copy


 def shallow_fusion(
     LG: k2.Fsa,
     token: int,
     state_and_scores: Dict[int, torch.Tensor],
+    vocab_size: int,
 ) -> Dict[int, torch.Tensor]:
     """
     Args:
       LG:
-        An n-gram. It should be arc sorted and epsilon free.
+        An n-gram. It should be arc sorted, deterministic, and epsilon free.
       token:
         The input token ID.
       state_and_scores:
         The keys contain the current state we are in and the
         values are the LM log_prob for reaching the corresponding
         states from the start state.
+      vocab_size:
+        Vocabulary size, including the blank symbol. We assume that
+        token IDs >= vocab_size are disambig IDs (including the backoff
+        symbol #0).
     Returns:
       Return a new state_and_scores.
     """
     row_splits = LG.arcs.row_splits(1)
     arcs = LG.arcs.values()

+    state_and_scores = copy.deepcopy(state_and_scores)
+
     current_states = list(state_and_scores.keys())

+    # Process out-going arcs with label being disambig tokens and #0
+    while len(current_states) > 0:
+        s = current_states.pop()
+        labels_begin = row_splits[s]
+        labels_end = row_splits[s + 1]
+        labels = LG.labels[labels_begin:labels_end].contiguous()
+
+        for i in reversed(range(labels.numel())):
+            lab = labels[i]
+            if lab == -1:
+                # Note: When sorting arcs, k2 treats arc labels as
+                # unsigned types
+                continue
+
+            if lab < vocab_size:
+                # Since LG is arc sorted, we can exit
+                # the for loop as soon as we have a label
+                # with ID less than vocab_size
+                break
+
+            # This is a disambig token or #0
+            idx = labels_begin + i
+            next_state = arcs[idx][1].item()
+            score = LG.scores[idx] + state_and_scores[s]
+            if next_state not in state_and_scores:
+                state_and_scores[next_state] = score
+                current_states.append(next_state)
+            else:
+                state_and_scores[next_state] = max(
+                    score, state_and_scores[next_state]
+                )
+
+    current_states = list(state_and_scores.keys())
     ans = dict()
     for s in current_states:
         labels_begin = row_splits[s]
         labels_end = row_splits[s + 1]
         labels = LG.labels[labels_begin:labels_end].contiguous()

-        # As LG is not deterministic, there may be multiple
-        # out-going arcs that with label equal to "token"
-        #
-        # Note: LG is arc sorted!
-        left = torch.bucketize(token, labels, right=False)
-        right = torch.bucketize(token, labels, right=True)
-
-        if left >= right:
-            # There are no out-going arcs from this state
-            # that have label equal to "token"
-            continue
-
-        # Now we have
-        # labels[i] == token
-        # for
-        # left <= i < right
-
-        for i in range(left, right):
-            i += labels_begin
-            next_state = arcs[i][1].item()
-            score = LG.scores[i]
-            if next_state not in ans:
-                ans[next_state] = score
-            else:
-                ans[next_state] = max(score, ans[next_state])
+        if labels[-1] == -1:
+            labels = labels[:-1]
+
+        pos = torch.searchsorted(labels, token)
+        if pos >= labels.numel() or labels[pos] != token:
+            continue
+
+        idx = labels_begin + pos
+        next_state = arcs[idx][1].item()
+        score = LG.scores[idx] + state_and_scores[s]
+
+        if next_state not in ans:
+            ans[next_state] = score
+        else:
+            ans[next_state] = max(score, ans[next_state])

     return ans
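Since LG is now deterministic, each state has at most one out-going arc per label, so the old bucketize left/right range collapses to a single torch.searchsorted lookup. A self-contained illustration of that lookup (toy labels; final arcs labeled -1 are stripped first, as above):

import torch

labels = torch.tensor([3, 7, 42, 500])  # sorted labels of one state's arcs
token = 42

pos = torch.searchsorted(labels, token)
if pos < labels.numel() and labels[pos] == token:
    print(f"arc for token {token} is at offset {int(pos)} within this state")
else:
    print("this state has no out-going arc with that label")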