diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py index d87fdf1b9..97301691e 100755 --- a/egs/librispeech/ASR/local/compile_lg.py +++ b/egs/librispeech/ASR/local/compile_lg.py @@ -105,31 +105,15 @@ def compile_LG(lang_dir: str) -> k2.Fsa: f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}" ) - logging.info("Removing disambiguation symbols on LG") - LG.labels[LG.labels >= first_token_disambig_id] = 0 - # See https://github.com/k2-fsa/k2/issues/874 - # for why we need to set LG.properties to None - LG.__dict__["_properties"] = None - - logging.info("Removing epsilons") - LG = k2.remove_epsilon(LG) - logging.info( - f"LG shape after k2.remove_epsilon: {LG.shape}, num_arcs: {LG.num_arcs}" - ) - - logging.info("Connecting") - LG = k2.connect(LG) - logging.info( - f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}" - ) - logging.info("Arc sorting LG") LG = k2.arc_sort(LG) logging.info(f"LG properties: {LG.properties_str}") # Possible properties is: - # "Valid|Nonempty|ArcSorted|EpsilonFree|MaybeAccessible|MaybeCoaccessible" - logging.info("Caution: LG is not deterministic!!!") + # "Valid|Nonempty|ArcSorted|ArcSortedAndDeterministic|EpsilonFree|MaybeAccessible|MaybeCoaccessible" # noqa + logging.info( + "Caution: LG is deterministic and contains disambig symbols!!!" + ) return LG diff --git a/egs/librispeech/ASR/local/test_compile_lg.py b/egs/librispeech/ASR/local/test_compile_lg.py index 35eee0586..5131551f9 100755 --- a/egs/librispeech/ASR/local/test_compile_lg.py +++ b/egs/librispeech/ASR/local/test_compile_lg.py @@ -21,58 +21,75 @@ To run this file, do: python ./local/test_compile_lg.py """ +import os + from pathlib import Path -from typing import List import k2 -import sentencepiece as spm import torch lang_dir = Path("./data/lang_bpe_500") +corpus = "test_compile_lg_corpus.txt" +arpa = "test_compile_lg_3_gram.arpa" +G_fst_txt = "test_compile_lg_3_gram.fst.txt" -def get_word_ids(word_table: k2.SymbolTable, s: str) -> List[int]: +def generate_corpus(): + s = """HELLO WORLD +HELLOA WORLDER +HELLOA WORLDER HELLO +HELLOA WORLDER""" + with open(corpus, "w") as f: + f.write(s) + + +def generate_arpa(): + cmd = f""" + ./shared/make_kn_lm.py \ + -ngram-order 3 \ + -text {corpus} \ + -lm {arpa} """ - Args: - word_table: - Word symbol table. - s: - A string consisting of space(s) separated words. - Returns: - Return a list of word IDs. + os.system(cmd) + + +def generate_G(): + cmd = f""" + python3 -m kaldilm \ + --read-symbol-table="{lang_dir}/words.txt" \ + --disambig-symbol='#0' \ + {arpa} > {G_fst_txt} """ - ans = [] - for w in s.split(): - ans.append(word_table[w]) - return ans + os.system(cmd) def main(): - assert lang_dir.exists(), f"{lang_dir} does not exist!" 
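+    # Build a tiny corpus, estimate a 3-gram ARPA LM with
+    # shared/make_kn_lm.py, and convert it to an FST with kaldilm.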
-    LG = k2.Fsa.from_dict(torch.load(f"{lang_dir}/LG.pt", map_location="cpu"))
+    generate_corpus()
+    generate_arpa()
+    generate_G()
+    with open(G_fst_txt) as f:
+        G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+    del G.aux_labels
+    G.labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")
+    G.draw("G.pdf", title="G")
 
-    sp = spm.SentencePieceProcessor()
-    sp.load(f"{lang_dir}/bpe.model")
+    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+    L.labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/tokens.txt")
+    L.aux_labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")
 
-    word_table = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")
-    s = "HELLO WORLD"
-    token_ids = sp.encode(s)
+    L = k2.arc_sort(L)
+    G = k2.arc_sort(G)
 
-    token_fsa = k2.linear_fsa(token_ids)
+    LG = k2.compose(L, G)
+    del LG.aux_labels
 
-    fsa = k2.intersect(LG, token_fsa)
-    fsa = k2.connect(fsa)
-    print(k2.to_dot(fsa))
-    print(fsa.properties_str)
+    LG = k2.determinize(LG)
+    LG = k2.connect(LG)
+    LG = k2.arc_sort(LG)
     print(LG.properties_str)
-    # You can use https://dreampuf.github.io/GraphvizOnline/
-    # to visualize the output.
-    #
-    # You can see that the resulting fsa is not deterministic
-    # Note: LG is non-deterministic
-    #
-    # See https://shorturl.at/uIL69
-    # for visualization of the above fsa.
+    LG.draw("LG.pdf", title="LG")
+    # You can have a look at G.pdf and LG.pdf to get a feel
+    # for what they look like.
 
 
 if __name__ == "__main__":
diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py
index a390c3180..ecddbf5a9 100644
--- a/egs/librispeech/ASR/transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py
@@ -155,9 +155,14 @@ class HypothesisList(object):
         key = hyp.key
         if key in self:
             old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(
-                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
-            )
+
+            # Merge duplicate hypotheses with log-add. An alternative
+            # is to keep only the better of the two:
+            #   old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob)
+            old_hyp.log_prob = torch.logaddexp(
+                old_hyp.log_prob, hyp.log_prob
+            )
+
             if hyp.ngram_state_and_scores is not None:
                 for state, score in hyp.ngram_state_and_scores.items():
                     if (
@@ -337,6 +342,7 @@ def modified_beam_search(
     encoder_out: torch.Tensor,
     beam: int = 4,
     LG: Optional[k2.Fsa] = None,
+    ngram_lm_scale: float = 0.1,
 ) -> List[int]:
     """It limits the maximum number of symbols per frame to 1.
 
@@ -349,11 +355,13 @@
       Beam size.
     LG:
       Optional. Used for shallow fusion.
+    ngram_lm_scale:
+      Used only when LG is not None. The total score of a path is
+      am_score + ngram_lm_scale * ngram_lm_score.
     Returns:
       Return the decoded result.
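+
+    Example (a sketch; it assumes `model` and `encoder_out` have been
+    built already and that LG.pt was produced by ./local/compile_lg.py):
+
+        LG = k2.Fsa.from_dict(torch.load("data/lang_bpe_500/LG.pt"))
+        hyp = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=4,
+            LG=LG,
+            ngram_lm_scale=0.1,
+        )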
""" enable_shallow_fusion = LG is not None - ngram_lm_scale = 0.8 assert encoder_out.ndim == 3 @@ -422,6 +430,7 @@ def modified_beam_search( encoder_out_len.expand(decoder_out.size(0)), decoder_out_len.expand(decoder_out.size(0)), ) + vocab_size = logits.size(-1) # logits is of shape (num_hyps, vocab_size) log_probs = logits.log_softmax(dim=-1) @@ -437,6 +446,9 @@ def modified_beam_search( topk_hyp_indexes = topk_hyp_indexes.tolist() topk_token_indexes = topk_token_indexes.tolist() + # import pdb + # + # pdb.set_trace() for i in range(len(topk_hyp_indexes)): hyp = A[topk_hyp_indexes[i]] new_ys = hyp.ys[:] @@ -450,12 +462,15 @@ def modified_beam_search( if enable_shallow_fusion and new_token != blank_id: ngram_state_and_scores = shallow_fusion( - LG, new_token, hyp.ngram_state_and_scores + LG, + new_token, + hyp.ngram_state_and_scores, + vocab_size, ) if len(ngram_state_and_scores) == 0: continue max_ngram_score = max(ngram_state_and_scores.values()) - new_log_prob += ngram_lm_scale * max_ngram_score + new_log_prob = new_log_prob + ngram_lm_scale * max_ngram_score # TODO: Get the maximum scores in ngram_state_and_scores # and add it to new_log_prob @@ -468,6 +483,9 @@ def modified_beam_search( B.add(new_hyp) if len(B) == 0: + import logging + + logging.info("\n*****\nEmpty states!\n***\n") for h in A: B.add(h) diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py index d4da693a5..b70b97d70 100755 --- a/egs/librispeech/ASR/transducer_stateless/decode.py +++ b/egs/librispeech/ASR/transducer_stateless/decode.py @@ -139,6 +139,15 @@ def get_parser(): Used only when --decoding-method is modified_beam_search.""", ) + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.1, + help="""Used when only --LG is provided. + The total score of a path is am_score + ngram_lm_scale * ngram_lm_score. 
+ """, + ) + return parser @@ -279,14 +288,18 @@ def decode_one_batch( encoder_out=encoder_out_i, beam=params.beam_size, LG=LG, + ngram_lm_scale=params.ngram_lm_scale, ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" ) hyps.append(sp.decode(hyp).split()) + s = "\n" for h in hyps: - print(" ".join(h)) + s += " ".join(h) + s += "\n" + logging.info(s) if params.decoding_method == "greedy_search": return {"greedy_search": hyps} @@ -336,6 +349,8 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): + if batch_idx > 10: + break texts = batch["supervisions"]["text"] hyps_dict = decode_one_batch( @@ -452,9 +467,10 @@ def main(): logging.info(f"LG properties: {LG.properties_str}") logging.info(f"LG num_states: {LG.shape[0]}, num_arcs: {LG.num_arcs}") # If LG is created by local/compile_lg.py, then it should be epsilon - # free as well as arc sorted + # free, deterministic, and arc sorted assert "ArcSorted" in LG.properties_str assert "EpsilonFree" in LG.properties_str + assert "Deterministic" in LG.properties_str else: LG = None @@ -501,6 +517,8 @@ def main(): test_dl = [test_clean_dl, test_other_dl] for test_set, test_dl in zip(test_sets, test_dl): + if test_set == "test-other": + break results_dict = decode_dataset( dl=test_dl, params=params, diff --git a/egs/librispeech/ASR/transducer_stateless/shallow_fusion.py b/egs/librispeech/ASR/transducer_stateless/shallow_fusion.py index 18ef253dd..8f1045d45 100644 --- a/egs/librispeech/ASR/transducer_stateless/shallow_fusion.py +++ b/egs/librispeech/ASR/transducer_stateless/shallow_fusion.py @@ -18,61 +18,92 @@ from typing import Dict import k2 import torch +import copy def shallow_fusion( LG: k2.Fsa, token: int, state_and_scores: Dict[int, torch.Tensor], + vocab_size: int, ) -> Dict[int, torch.Tensor]: """ Args: LG: - An n-gram. It should be arc sorted and epsilon free. + An n-gram. It should be arc sorted, deterministic, and epsilon free. token: The input token ID. state_and_scores: The keys contain the current state we are in and the values are the LM log_prob for reaching the corresponding states from the start state. + vocab_size: + Vocabulary size, including the blank symbol. We assume that + token IDs >= vocab_size are disambig IDs (including the backoff + symbol #0). Returns: Return a new state_and_scores. 
""" row_splits = LG.arcs.row_splits(1) arcs = LG.arcs.values() + state_and_scores = copy.deepcopy(state_and_scores) + current_states = list(state_and_scores.keys()) + # Process out-going arcs with label being disambig tokens and #0 + while len(current_states) > 0: + s = current_states.pop() + labels_begin = row_splits[s] + labels_end = row_splits[s + 1] + labels = LG.labels[labels_begin:labels_end].contiguous() + + for i in reversed(range(labels.numel())): + lab = labels[i] + if lab == -1: + # Note: When sorting arcs, k2 treats arc labels as + # unsigned types + continue + + if lab < vocab_size: + # Since LG is arc sorted, we can exit + # the for loop as soon as we have a label + # with ID less than vocab_size + break + + # This is a diambig token or #0 + idx = labels_begin + i + next_state = arcs[idx][1].item() + score = LG.scores[idx] + state_and_scores[s] + if next_state not in state_and_scores: + state_and_scores[next_state] = score + current_states.append(next_state) + else: + state_and_scores[next_state] = max( + score, state_and_scores[next_state] + ) + + current_states = list(state_and_scores.keys()) ans = dict() for s in current_states: labels_begin = row_splits[s] labels_end = row_splits[s + 1] labels = LG.labels[labels_begin:labels_end].contiguous() - # As LG is not deterministic, there may be multiple - # out-going arcs that with label equal to "token" - # - # Note: LG is arc sorted! - left = torch.bucketize(token, labels, right=False) - right = torch.bucketize(token, labels, right=True) + if labels[-1] == -1: + labels = labels[:-1] - if left >= right: - # There are no out-going arcs from this state - # that have label equal to "token" + pos = torch.searchsorted(labels, token) + if pos >= labels.numel() or labels[pos] != token: continue - # Now we have - # labels[i] == token - # for - # left <= i < right + idx = labels_begin + pos + next_state = arcs[idx][1].item() + score = LG.scores[idx] + state_and_scores[s] - for i in range(left, right): - i += labels_begin - next_state = arcs[i][1].item() - score = LG.scores[i] - if next_state not in ans: - ans[next_state] = score - else: - ans[next_state] = max(score, ans[next_state]) + if next_state not in ans: + ans[next_state] = score + else: + ans[next_state] = max(score, ans[next_state]) return ans