From 07587d106a65558dc539f38f826219c28b8c9665 Mon Sep 17 00:00:00 2001
From: pkufool
Date: Fri, 5 May 2023 14:42:47 +0800
Subject: [PATCH] Fix context-graph scoring bugs; add --num-context-history

---
 .../beam_search.py                          |  49 +++---
 .../pruned_transducer_stateless4/decode.py  |  19 ++-
 .../ASR/tdnn_lstm_ctc/asr_datamodule.py     |   8 +
 .../pruned_transducer_stateless5/decode.py  |  11 +-
 icefall/context_graph.py                    | 155 +++++++++++++++---
 5 files changed, 186 insertions(+), 56 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
index 51a321572..9655a23f1 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import warnings
 from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Tuple, Union
@@ -750,6 +751,9 @@ class Hypothesis:
         """Return a string representation of self.ys"""
         return "_".join(map(str, self.ys))
 
+    def __str__(self) -> str:
+        return (
+            f"ys: {'_'.join(map(str, self.ys))}, "
+            f"log_prob: {float(self.log_prob):.2f}, "
+            f"state: {self.context_state}"
+        )
+
 
 class HypothesisList(object):
     def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
@@ -887,6 +891,7 @@ def modified_beam_search(
     encoder_out: torch.Tensor,
     encoder_out_lens: torch.Tensor,
     context_graph: Optional[ContextGraph] = None,
+    num_context_history: int = 1,
     beam: int = 4,
     temperature: float = 1.0,
     return_timestamps: bool = False,
@@ -938,7 +943,7 @@ def modified_beam_search(
             Hypothesis(
                 ys=[blank_id] * context_size,
                 log_prob=torch.zeros(1, dtype=torch.float32, device=device),
-                context_state=ContextState(state_id=0),
+                context_state=None
+                if context_graph is None
+                else ContextState(
+                    graph=context_graph, max_states=num_context_history
+                ),
                 timestamp=[],
             )
         )
@@ -961,6 +966,7 @@ def modified_beam_search(
         hyps_shape = get_hyps_shape(B).to(device)
 
         A = [list(b) for b in B]
+
         B = [HypothesisList() for _ in range(batch_size)]
 
         ys_log_probs = torch.cat(
@@ -1018,30 +1024,24 @@ def modified_beam_search(
             for k in range(len(topk_hyp_indexes)):
                 hyp_idx = topk_hyp_indexes[k]
                 hyp = A[i][hyp_idx]
                 new_ys = hyp.ys[:]
                 new_token = topk_token_indexes[k]
                 new_timestamp = hyp.timestamp[:]
-                new_context_state = None
+                context_score = 0
+                new_context_state = (
+                    None if context_graph is None else hyp.context_state.clone()
+                )
                 if new_token not in (blank_id, unk_id):
                     new_ys.append(new_token)
                     new_timestamp.append(t)
                     if context_graph is not None:
-                        new_context_state = context_graph.get_next_state(
-                            hyp.context_state.state_id, new_token
-                        )
-                new_log_prob = topk_log_probs[k] + (
-                    0
-                    if new_context_state is None
-                    else new_context_state.score
-                )
+                        (
+                            context_score,
+                            new_context_state,
+                        ) = hyp.context_state.forward_one_step(new_token)
+
+                new_log_prob = topk_log_probs[k] + context_score
+
                 new_hyp = Hypothesis(
                     ys=new_ys,
                     log_prob=new_log_prob,
                     timestamp=new_timestamp,
-                    context_state=hyp.context_state
-                    if new_context_state is None
-                    else new_context_state,
+                    context_state=new_context_state,
                 )
                 B[i].add(new_hyp)
 
@@ -1053,20 +1053,15 @@ def modified_beam_search(
     if context_graph is not None:
         finalized_B = [HypothesisList() for _ in range(len(B))]
         for i, hyps in enumerate(B):
             for hyp in list(hyps):
-                if hyp.context_state.state_id != 0:
-                    new_context_state = context_graph.get_next_state(
-                        hyp.context_state.state_id, 0
+                context_score, new_context_state = hyp.context_state.finalize()
+
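+                # finalize() returns a score (0 or negative) that cancels the
+                # pending bonus of any unfinished partial match, plus a fresh
+                # context state, so half-matched phrases do not inflate the
+                # final log-prob.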
+                finalized_B[i].add(
+                    Hypothesis(
+                        ys=hyp.ys,
+                        log_prob=hyp.log_prob + context_score,
+                        timestamp=hyp.timestamp,
+                        context_state=new_context_state,
+                    )
+                )
-                    )
-                    finalized_B[i].add(
-                        Hypothesis(
-                            ys=hyp.ys,
-                            log_prob=hyp.log_prob + new_context_state.score,
-                            timestamp=hyp.timestamp,
-                            context_state=new_context_state,
-                        )
-                    )
-                else:
-                    finalized_B[i].add(hyp)
         B = finalized_B
 
     best_hyps = [b.get_most_probable(length_norm=True) for b in B]
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
index ea1ae49cd..f99b5f54c 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
@@ -131,6 +131,8 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
 import k2
+import kaldifst
+import graphviz
 import sentencepiece as spm
 import torch
 import torch.nn as nn
@@ -363,6 +365,13 @@ def get_parser():
         help="",
     )
 
+    parser.add_argument(
+        "--num-context-history",
+        type=int,
+        default=1,
+        help="Number of highest-scoring partial context-graph matches "
+        "kept per hypothesis (passed to ContextState as max_states).",
+    )
+
     parser.add_argument(
         "--context-file",
         type=str,
@@ -511,6 +520,7 @@ def decode_one_batch(
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
             context_graph=context_graph,
+            num_context_history=params.num_context_history,
             return_timestamps=True,
         )
     else:
@@ -565,7 +575,10 @@ def decode_one_batch(
 
         return {key: (hyps, timestamps)}
     else:
-        return {f"beam_size_{params.beam_size}": (hyps, timestamps)}
+        key = f"beam_size_{params.beam_size}"
+        key += f"-context-score-{params.context_score}"
+        key += f"-num-context-history-{params.num_context_history}"
+        return {key: (hyps, timestamps)}
 
 
 def decode_dataset(
@@ -614,7 +627,7 @@ def decode_dataset(
     if params.decoding_method == "greedy_search":
         log_interval = 50
     else:
-        log_interval = 20
+        log_interval = 1
 
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@@ -769,6 +782,8 @@ def main():
         params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += f"-context-score-{params.context_score}"
+        params.suffix += f"-num-context-history-{params.num_context_history}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 41698d00a..c60fce597 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -381,6 +381,7 @@ class LibriSpeechAsrDataModule:
         )
         sampler = DynamicBucketingSampler(
             cuts,
+            num_buckets=2,
             max_duration=self.args.max_duration,
             shuffle=False,
         )
@@ -452,6 +453,13 @@ class LibriSpeechAsrDataModule:
             self.args.manifest_dir / "libri_books_feats.jsonl.gz"
         )
 
+    @lru_cache()
+    def test_book_test_cuts(self) -> CutSet:
+        logging.info("About to get test-book-test cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libri_book_test_feats.jsonl.gz"
+        )
+
     @lru_cache()
     def test_book2_cuts(self) -> CutSet:
         logging.info("About to get test-books2 cuts")
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
index f7cd4cbef..324c39ce5 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
@@ -287,6 +287,13 @@ def get_parser():
         help="",
     )
 
+    parser.add_argument(
+        "--num-context-history",
+        type=int,
+        default=1,
+        help="Number of highest-scoring partial context-graph matches "
+        "kept per hypothesis (passed to ContextState as max_states).",
+    )
+
     parser.add_argument(
         "--context-file",
         type=str,
@@ -389,6 +396,7 @@ def decode_one_batch(
             beam=params.beam_size,
             encoder_out_lens=encoder_out_lens,
             context_graph=context_graph,
+            num_context_history=params.num_context_history,
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
@@ -429,7 +437,7 @@ def decode_one_batch(
         }
     else:
         return {
-            f"beam_size_{params.beam_size}_context_score_{params.context_score}": hyps
+            f"beam_size_{params.beam_size}"
+            f"_context_score_{params.context_score}"
+            f"_num_context_history_{params.num_context_history}": hyps
         }
 
 
@@ -568,6 +576,7 @@ def main():
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam_size}"
         params.suffix += f"-context-score-{params.context_score}"
+        params.suffix += f"-num-context-history-{params.num_context_history}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/icefall/context_graph.py b/icefall/context_graph.py
index 9f4a26891..6cbcd043a 100644
--- a/icefall/context_graph.py
+++ b/icefall/context_graph.py
@@ -14,11 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-import logging
+from heapq import heappush, heappop
 import re
-from dataclasses import dataclass
-from typing import List
+from typing import List, Tuple
 import argparse
 import k2
 import kaldifst
@@ -27,17 +26,13 @@ import sentencepiece as spm
 from icefall.utils import is_module_available
 
 
-@dataclass
-class ContextState:
-    state_id: int = 0
-    score: float = 0.0
-
-
 class ContextGraph:
     def __init__(self, context_score: float = 1):
         self.context_score = context_score
 
-    def build_context_graph_char(self, contexts: List[str], token_table: k2.SymbolTable):
+    def build_context_graph_char(
+        self, contexts: List[str], token_table: k2.SymbolTable
+    ):
         """Convert a list of texts to a list-of-list of token IDs.
 
         Args:
@@ -56,7 +51,7 @@ class ContextGraph:
         whitespace = re.compile(r"([ \t])")
         for text in contexts:
             text = re.sub(whitespace, "", text)
-            sub_ids : List[int] = []
+            sub_ids: List[int] = []
             skip = False
             for txt in text:
                 if txt not in token_table:
@@ -69,7 +64,9 @@ class ContextGraph:
             ids.append(sub_ids)
         self.build_context_graph(ids)
 
-    def build_context_graph_bpe(self, contexts: List[str], sp: spm.SentencePieceProcessor):
+    def build_context_graph_bpe(
+        self, contexts: List[str], sp: spm.SentencePieceProcessor
+    ):
+        """Build the context graph from texts encoded with a BPE model."""
         contexts_bpe = sp.encode(contexts)
         self.build_context_graph(contexts_bpe)
 
@@ -80,7 +77,7 @@ class ContextGraph:
         )  # 1st state will be state 0 (returned by add_state)
         assert start_state == 0, start_state
         graph.start = 0  # set the start state to 0
-        graph.set_final(start_state, weight=0)  # weight is in log space
+        graph.set_final(start_state, weight=kaldifst.TropicalWeight.one)
 
         for tokens in token_ids:
             prev_state = start_state
@@ -111,22 +108,128 @@ class ContextGraph:
             prev_state = next_state
             backoff_score += score
         self.graph = kaldifst.determinize(graph)
+        kaldifst.arcsort(self.graph)
 
-    def get_next_state(self, state_id: int, label: int) -> ContextState:
-        next_state = 0
-        score = 0
-        for arc in kaldifst.ArcIterator(self.graph, state_id):
-            if arc.ilabel == 0:
-                score = arc.weight.value
-            elif arc.ilabel == label:
-                next_state = arc.nextstate
-                score = arc.weight.value
-                break
-        return ContextState(
-            state_id=next_state,
-            score=score,
-        )
+    def is_final_state(self, state_id: int) -> bool:
+        """Return True if the given state ends a context phrase."""
+        return self.graph.final(state_id) == kaldifst.TropicalWeight.one
+
+    def get_next_state(self, state_id: int, label: int) -> Tuple[int, float, bool]:
+        """Look up the arc with the given label leaving `state_id`.
+
+        Returns a tuple (next_state, score, matched). If no arc matches,
+        we back off to the start state (state 0).
+        """
+        arc_iter = kaldifst.ArcIterator(self.graph, state_id)
+        num_arcs = self.graph.num_arcs(state_id)
+
+        # The graph is arc-sorted by ilabel, so we can use binary search.
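+        # Invariant: arcs with index < left have ilabel < label, and arcs
+        # with index > right have ilabel > label.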
+        left = 0
+        right = num_arcs - 1
+        while left <= right:
+            mid = (left + right) // 2
+            arc_iter.seek(mid)
+            arc = arc_iter.value
+            if arc.ilabel < label:
+                left = mid + 1
+            elif arc.ilabel > label:
+                right = mid - 1
+            else:
+                return (arc.nextstate, arc.weight.value, True)
+
+        # No matching arc: back off to the start state, reporting the score
+        # on the epsilon arc (ilabel == 0) if this state has one.
+        if num_arcs == 0:
+            return (0, 0, False)
+        arc_iter.seek(0)
+        arc = arc_iter.value
+        if arc.ilabel == 0:
+            return (0, arc.weight.value, False)
+        else:
+            return (0, 0, False)
+
+
+class ContextState:
+    """The active matching states of a ContextGraph for one decoding
+    hypothesis. At most `max_states` partial matches are tracked.
+    """
+
+    def __init__(self, graph: ContextGraph, max_states: int):
+        self.graph = graph
+        self.max_states = max_states
+        # A min-heap of (-total_score, (arc_score, state_id)); because the
+        # total scores are negated, heappop yields the partial match with
+        # the highest total score first.
+        self.states: List[Tuple[float, Tuple[float, int]]] = []
+
+    def __str__(self):
+        return ";".join([str(state) for state in self.states])
+
+    def clone(self) -> "ContextState":
+        new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
+        new_context_state.states = self.states[:]
+        return new_context_state
+
+    def finalize(self) -> Tuple[float, "ContextState"]:
+        """Stop matching at the end of the utterance.
+
+        Returns (score, new_context_state); the score (0 or negative)
+        cancels the pending bonus of the best unfinished partial match.
+        """
+        new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
+        if len(self.states) == 0:
+            return 0, new_context_state
+        # self.states[0] holds the highest total score, negated.
+        return self.states[0][0], new_context_state
+
+    def forward_one_step(self, label: int) -> Tuple[float, "ContextState"]:
+        """Advance all active matches by one label.
+
+        Returns (score, new_context_state), where score is the change in
+        context score to add to the log-prob of the current hypothesis.
+        """
+        states = self.states[:]
+        new_states: List[Tuple[float, Tuple[float, int]]] = []
+
+        # Try to start a new match of `label` from the start state.
+        status = self.graph.get_next_state(0, label)
+        if status[2]:
+            heappush(new_states, (-status[1], (status[1], status[0])))
+        else:
+            assert status[0] == 0 and not status[2], status
+
+        # The best score we have already added to this path so far.
+        prev_max_total_score = 0
+        # Extend the previous states with the given label.
+        while states:
+            state = heappop(states)
+            if -state[0] > prev_max_total_score:
+                prev_max_total_score = -state[0]
+
+            status = self.graph.get_next_state(state[1][1], label)
+
+            if status[2]:
+                heappush(new_states, (state[0] - status[1], (status[1], status[0])))
+            # Otherwise this partial match cannot continue with `label`
+            # and is dropped.
+
+        states = []
+        if len(new_states) == 0:
+            # All partial matches died out; cancel the score added so far.
+            new_context_state = ContextState(
+                graph=self.graph, max_states=self.max_states
+            )
+            return -prev_max_total_score, new_context_state
+
+        # new_states[0] holds the best (highest) total score, negated.
+        max_total_score = -new_states[0][0]
+        while new_states:
+            item = heappop(new_states)
+
+            # If an item reaches a final state, a context phrase is fully
+            # matched: clear all states (the next label starts a new match)
+            # and return the score of the matched phrase.
+            if self.graph.is_final_state(item[1][1]):
+                new_context_state = ContextState(
+                    graph=self.graph, max_states=self.max_states
+                )
+                return -item[0] - prev_max_total_score, new_context_state
+
+            # Keep only the `max_states` highest-scoring partial matches;
+            # the heap pops them in order of decreasing total score.
+            if len(states) < self.max_states:
+                heappush(states, item)
+
+        # No context fully matched; matching may continue with the current
+        # prefixes, or switch to another prefix later.
+        new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
+        new_context_state.states = states
+        return max_total_score - prev_max_total_score, new_context_state
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
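+    # Editorial sketch (not part of the original patch) of how the classes
+    # above fit together during decoding; the BPE model path below is
+    # hypothetical:
+    #
+    #     sp = spm.SentencePieceProcessor()
+    #     sp.load("path/to/bpe.model")  # hypothetical path
+    #     context_graph = ContextGraph(context_score=2.0)
+    #     context_graph.build_context_graph_bpe(["HELLO WORLD"], sp)
+    #
+    #     # One ContextState per hypothesis; max_states corresponds to the
+    #     # --num-context-history option added in this patch.
+    #     state = ContextState(graph=context_graph, max_states=1)
+    #     total = 0.0
+    #     for token in sp.encode("HELLO WORLD"):
+    #         score, state = state.forward_one_step(token)
+    #         total += score  # incremental context bonus for this token
+    #     score, state = state.finalize()
+    #     total += score  # cancels any unfinished partial match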