diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index c60fce597..5b5e9b1cf 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -449,9 +449,7 @@ class LibriSpeechAsrDataModule:
     @lru_cache()
     def test_book_cuts(self) -> CutSet:
         logging.info("About to get test-books cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "libri_books_feats.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "libri_books_feats.jsonl.gz")
 
     @lru_cache()
     def test_book_test_cuts(self) -> CutSet:
diff --git a/icefall/context_graph.py b/icefall/context_graph.py
index 61eb5090c..c78de30f6 100644
--- a/icefall/context_graph.py
+++ b/icefall/context_graph.py
@@ -61,6 +61,7 @@ class ContextState:
         self.is_end = is_end
         self.next = {}
         self.fail = None
+        self.output = None
 
 
 class ContextGraph:
@@ -97,7 +98,7 @@ class ContextGraph:
         )
         self.root.fail = self.root
 
-    def _fill_fail(self):
+    def _fill_fail_output(self):
         """This function fills the fail arc for each trie node, it can be computed
         in linear time by performing a breadth-first search starting from the root.
         See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for the
@@ -122,6 +123,14 @@
                     if token in fail.next:
                         fail = fail.next[token]
                 node.fail = fail
+                # fill the output arc
+                output = node.fail
+                while not output.is_end:
+                    output = output.fail
+                    if output.token == -1:  # root
+                        output = None
+                        break
+                node.output = output
                 queue.append(node)
 
     def build(self, token_ids: List[List[int]]):
@@ -155,7 +164,7 @@ class ContextGraph:
                     is_end=is_end,
                 )
                 node = node.next[token]
-        self._fill_fail()
+        self._fill_fail_output()
 
     def forward_one_step(
         self, state: ContextState, token: int
@@ -171,11 +180,14 @@
         Returns:
           Return a tuple of score and next state.
         """
+        node = None
+        score = 0
         # token matched
         if token in state.next:
             node = state.next[token]
             score = node.token_score
-            return (score, node)
+            if state.is_end:
+                score += state.node_score
         else:
             # token not matched
             # We will trace along the fail arc until it matches the token or reaching
@@ -191,7 +203,13 @@
 
             # The score of the fail path
             score = node.node_score - state.local_node_score
-        return (score, node)
+        assert node is not None
+        matched_score = 0
+        output = node.output
+        while output is not None:
+            matched_score += output.node_score
+            output = output.output
+        return (score + matched_score, node)
 
     def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
         """When reaching the end of the decoded sequence, we need to finalize
@@ -208,7 +226,9 @@
           to root. The next state is always root.
         """
         # The score of the fail arc
-        score = self.root.node_score - state.local_node_score
+        score = -state.node_score
+        if state.is_end:
+            score = 0
         return (score, self.root)
 
     def draw(
@@ -279,7 +299,7 @@
         queue.append(self.root)
         # root id is always 0
         dot.node("0", label="0", **default_node_attr)
-        dot.edge("0", "0", label=f"*/0")
+        dot.edge("0", "0", color="red")
         seen.add(0)
 
         while len(queue):
@@ -304,6 +324,12 @@
                     str(node.fail.id),
                     color="red",
                 )
+                if node.output is not None:
+                    dot.edge(
+                        str(node.id),
+                        str(node.output.id),
+                        color="green",
+                    )
                 queue.append(node)
 
         if filename:
@@ -331,7 +357,17 @@ class ContextGraph:
 
 
 if __name__ == "__main__":
-    contexts_str = ["HE", "SHE", "SHELL", "HIS", "HERS", "HELLO"]
+    contexts_str = [
+        "S",
+        "HE",
+        "SHE",
+        "SHELL",
+        "HIS",
+        "HERS",
+        "HELLO",
+        "THIS",
+        "THEM",
+    ]
     contexts = []
     for s in contexts_str:
         contexts.append([ord(x) for x in s])
@@ -350,9 +386,17 @@ if __name__ == "__main__":
         symbol_table=symbol_table,
     )
 
-    queries = ["HERSHE", "HISHE", "SHED", "HELL", "HELLO", "DHRHISQ"]
-    expected_scores = [7, 6, 3, 2, 5, 3]
-    for i, query in enumerate(queries):
+    queries = {
+        "HEHERSHE": 14,  # "HE", "HE", "HERS", "S", "SHE", "HE"
+        "HERSHE": 12,  # "HE", "HERS", "S", "SHE", "HE"
+        "HISHE": 9,  # "HIS", "S", "SHE", "HE"
+        "SHED": 6,  # "S", "SHE", "HE"
+        "HELL": 2,  # "HE"
+        "HELLO": 7,  # "HE", "HELLO"
+        "DHRHISQ": 4,  # "HIS", "S"
+        "THEN": 2,  # "HE"
+    }
+    for query, expected_score in queries.items():
         total_scores = 0
         state = context_graph.root
         for q in query:
@@ -361,8 +405,8 @@
             score, state = context_graph.finalize(state)
         assert state.token == -1, state.token
         total_scores += score
-        assert total_scores == expected_scores[i], (
+        assert total_scores == expected_score, (
             total_scores,
-            expected_scores[i],
+            expected_score,
             query,
         )