From 02e00ff504f3aaf72bfed427776ec5a4fecfcf97 Mon Sep 17 00:00:00 2001
From: pkufool
Date: Fri, 27 Sep 2024 19:31:54 +0800
Subject: [PATCH] Add librispeech prefix-beam-search

---
 egs/librispeech/ASR/zipformer/ctc_decode.py | 140 ++++++++--
 icefall/decode.py                           | 276 +++++++++++++++++---
 icefall/utils.py                            |  11 +
 3 files changed, 366 insertions(+), 61 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/ctc_decode.py b/egs/librispeech/ASR/zipformer/ctc_decode.py
index 9db429959..8f3dd10d2 100755
--- a/egs/librispeech/ASR/zipformer/ctc_decode.py
+++ b/egs/librispeech/ASR/zipformer/ctc_decode.py
@@ -123,6 +123,10 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from lhotse import set_caching_enabled
 from train import add_model_arguments, get_model, get_params
 
+from icefall.context_graph import ContextGraph, ContextState
+from icefall.ngram_lm import NgramLm, NgramLmStateCost
+from icefall.lm_wrapper import LmScorer
+
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -131,6 +135,9 @@ from icefall.checkpoint import (
 )
 from icefall.decode import (
     ctc_greedy_search,
+    ctc_prefix_beam_search,
+    ctc_prefix_beam_search_attention_decoder_rescoring,
+    ctc_prefix_beam_search_shallow_fussion,
     get_lattice,
     nbest_decoding,
     nbest_oracle,
@@ -280,6 +287,23 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--lm-type",
+        type=str,
+        default="rnn",
+        help="Type of NN lm",
+        choices=["rnn", "transformer"],
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.3,
+        help="""The scale of the neural network LM
+        Used only when `--use-shallow-fusion` is set to True.
+        """,
+    )
+
     parser.add_argument(
         "--hlg-scale",
         type=float,
@@ -301,7 +325,7 @@ def get_parser():
         "--skip-scoring",
         type=str2bool,
         default=False,
-        help="""Skip scoring, but still save the ASR output (for eval sets)."""
+        help="""Skip scoring, but still save the ASR output (for eval sets).""",
     )
 
     add_model_arguments(parser)
@@ -314,8 +338,9 @@ def get_decoding_params() -> AttributeDict:
     params = AttributeDict(
         {
             "frame_shift_ms": 10,
-            "search_beam": 20,
-            "output_beam": 8,
+            "search_beam": 20,  # for k2 fsa composition
+            "output_beam": 8,  # for k2 fsa composition
+            "beam": 4,  # for prefix-beam-search
             "min_active_states": 30,
             "max_active_states": 10000,
             "use_double_scores": True,
@@ -333,6 +358,7 @@ def decode_one_batch(
     batch: dict,
     word_table: k2.SymbolTable,
     G: Optional[k2.Fsa] = None,
+    LM: Optional[LmScorer] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -377,10 +403,7 @@ def decode_one_batch(
       Return the decoding result. See above description for the format of
       the returned dict. Note: If it decodes to nothing, then return None.
     """
-    if HLG is not None:
-        device = HLG.device
-    else:
-        device = H.device
+    device = params.device
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device)
@@ -411,6 +434,48 @@ def decode_one_batch(
         key = "ctc-greedy-search"
         return {key: hyps}
 
+    if params.decoding_method == "prefix-beam-search":
+        token_ids = ctc_prefix_beam_search(
+            ctc_output=ctc_output, encoder_out_lens=encoder_out_lens
+        )
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        key = "prefix-beam-search"
+        return {key: hyps}
+
+    if params.decoding_method == "ctc-prefix-beam-search-attention-decoder-rescoring":
+        best_path_dict = ctc_prefix_beam_search_attention_decoder_rescoring(
+            ctc_output=ctc_output,
+            attention_decoder=model.attention_decoder,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        ans = dict()
+        for a_scale_str, token_ids in best_path_dict.items():
+            # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+            hyps = bpe_model.decode(token_ids)
+            # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+            hyps = [s.split() for s in hyps]
+            ans[a_scale_str] = hyps
+        return ans
+
+    if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
+        token_ids = ctc_prefix_beam_search_shallow_fussion(
+            ctc_output=ctc_output,
+            encoder_out_lens=encoder_out_lens,
+            LM=LM,
+        )
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        key = "prefix-beam-search-shallow-fussion"
+        return {key: hyps}
+
     supervision_segments = torch.stack(
         (
             supervisions["sequence_idx"],
@@ -584,6 +649,7 @@ def decode_dataset(
     bpe_model: Optional[spm.SentencePieceProcessor],
     word_table: k2.SymbolTable,
     G: Optional[k2.Fsa] = None,
+    LM: Optional[LmScorer] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
 
@@ -634,6 +700,7 @@ def decode_dataset(
             batch=batch,
             word_table=word_table,
             G=G,
+            LM=LM,
         )
 
         for name, hyps in hyps_dict.items():
@@ -664,9 +731,7 @@ def save_asr_output(
     """
     for key, results in results_dict.items():
-        recogs_filename = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
 
         results = sorted(results)
         store_transcripts(filename=recogs_filename, texts=results)
@@ -680,7 +745,8 @@ def save_wer_results(
     results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
 ):
     if params.decoding_method in (
-        "attention-decoder-rescoring-with-ngram", "whole-lattice-rescoring"
+        "attention-decoder-rescoring-with-ngram",
+        "whole-lattice-rescoring",
     ):
         # Set it to False since there are too many logs.
         enable_log = False
@@ -721,6 +787,7 @@ def save_wer_results(
 def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
+    LmScorer.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
     args.lang_dir = Path(args.lang_dir)
@@ -735,8 +802,11 @@ def main():
     set_caching_enabled(True)  # lhotse
 
     assert params.decoding_method in (
-        "ctc-greedy-search",
         "ctc-decoding",
+        "ctc-greedy-search",
+        "prefix-beam-search",
+        "ctc-prefix-beam-search-attention-decoder-rescoring",
+        "ctc-prefix-beam-search-shallow-fussion",
         "1best",
         "nbest",
         "nbest-rescoring",
@@ -762,6 +832,11 @@ def main():
         params.suffix += f"_chunk-{params.chunk_size}"
         params.suffix += f"_left-context-{params.left_context_frames}"
 
+    if "prefix-beam-search" in params.decoding_method:
+        params.suffix += f"_beam-{params.beam}"
+        if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
+            params.suffix += f"_lm-scale-{params.lm_scale}"
+
     if params.use_averaged_model:
         params.suffix += "_use-averaged-model"
 
@@ -772,6 +847,8 @@ def main():
     if torch.cuda.is_available():
         device = torch.device("cuda", 0)
 
+    params.device = device
+
     logging.info(f"Device: {device}")
     logging.info(params)
 
@@ -786,14 +863,24 @@ def main():
     params.sos_id = 1
 
     if params.decoding_method in [
-        "ctc-greedy-search", "ctc-decoding", "attention-decoder-rescoring-no-ngram"
+        "ctc-greedy-search",
+        "ctc-decoding",
+        "attention-decoder-rescoring-no-ngram",
+        "prefix-beam-search",
+        "ctc-prefix-beam-search-attention-decoder-rescoring",
+        "ctc-prefix-beam-search-shallow-fussion",
     ]:
         HLG = None
-        H = k2.ctc_topo(
-            max_token=max_token_id,
-            modified=False,
-            device=device,
-        )
+        H = None
+        if params.decoding_method in [
+            "ctc-decoding",
+            "attention-decoder-rescoring-no-ngram",
+        ]:
+            H = k2.ctc_topo(
+                max_token=max_token_id,
+                modified=False,
+                device=device,
+            )
         bpe_model = spm.SentencePieceProcessor()
         bpe_model.load(str(params.lang_dir / "bpe.model"))
     else:
@@ -844,7 +931,8 @@ def main():
             G = k2.Fsa.from_dict(d)
 
         if params.decoding_method in [
-            "whole-lattice-rescoring", "attention-decoder-rescoring-with-ngram"
+            "whole-lattice-rescoring",
+            "attention-decoder-rescoring-with-ngram",
         ]:
             # Add epsilon self-loops to G as we will compose
            # it with the whole lattice later
@@ -858,6 +946,19 @@ def main():
     else:
         G = None
 
+    # only load the neural network LM if required
+    if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
+        LM = LmScorer(
+            lm_type=params.lm_type,
+            params=params,
+            device=device,
+            lm_scale=params.lm_scale,
+        )
+        LM.to(device)
+        LM.eval()
+    else:
+        LM = None
+
     logging.info("About to create model")
     model = get_model(params)
 
@@ -967,6 +1068,7 @@ def main():
             bpe_model=bpe_model,
             word_table=lexicon.word_table,
             G=G,
+            LM=LM,
         )
 
         save_asr_output(
diff --git a/icefall/decode.py b/icefall/decode.py
index 5ec9296e1..d23ce2ebb 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -1511,30 +1511,43 @@ def ctc_greedy_search(
 class Hypothesis:
     # The predicted tokens so far.
     # Newly predicted tokens are appended to `ys`.
-    ys: List[int]
+    ys: List[int] = field(default_factory=list)
 
     # The log prob of ys.
     # It contains only one entry.
-    log_prob_blank: torch.Tensor
+    log_prob_blank: torch.Tensor = torch.zeros(1, dtype=torch.float32)
 
-    log_prob_non_blank: torch.Tensor
+    log_prob_non_blank: torch.Tensor = torch.tensor(
+        [float("-inf")], dtype=torch.float32
+    )
 
     # timestamp[i] is the frame index after subsampling
     # on which ys[i] is decoded
     timestamp: List[int] = field(default_factory=list)
 
-    # the lm score for next token given the current ys
-    lm_score: Optional[torch.Tensor] = None
+    # The lm score of ys
+    # It contains only one entry
+    lm_score: torch.Tensor = torch.zeros(1, dtype=torch.float32)
+
+    # the lm log_probs for next token given the history ys
+    lm_log_probs: Optional[torch.Tensor] = None
 
     # the RNNLM states (h and c in LSTM)
     state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
 
     # N-gram LM state
-    state_cost: Optional[NgramLmStateCost] = None
+    LODR_state: Optional[NgramLmStateCost] = None
+
+    # N-gram LM state
+    Ngram_state: Optional[NgramLmStateCost] = None
 
     # Context graph state
     context_state: Optional[ContextState] = None
 
+    @property
+    def tot_score(self) -> torch.Tensor:
+        return self.log_prob + self.lm_score
+
     @property
     def log_prob(self) -> torch.Tensor:
         return torch.logaddexp(self.log_prob_non_blank, self.log_prob_blank)
@@ -1544,6 +1557,20 @@ class Hypothesis:
         """Return a tuple representation of self.ys"""
         return tuple(self.ys)
 
+    def clone(self) -> "Hypothesis":
+        return Hypothesis(
+            ys=self.ys,
+            log_prob_blank=self.log_prob_blank,
+            log_prob_non_blank=self.log_prob_non_blank,
+            timestamp=self.timestamp,
+            lm_log_probs=self.lm_log_probs,
+            lm_score=self.lm_score,
+            state=self.state,
+            LODR_state=self.LODR_state,
+            Ngram_state=self.Ngram_state,
+            context_state=self.context_state,
+        )
+
 
 class HypothesisList(object):
     def __init__(self, data: Optional[Dict[tuple, Hypothesis]] = None) -> None:
@@ -1597,9 +1624,9 @@ class HypothesisList(object):
          Return the hypothesis that has the largest `log_prob`.
        """
        if length_norm:
-            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
+            return max(self._data.values(), key=lambda hyp: hyp.tot_score / len(hyp.ys))
        else:
-            return max(self._data.values(), key=lambda hyp: hyp.log_prob)
+            return max(self._data.values(), key=lambda hyp: hyp.tot_score)
 
     def remove(self, hyp: Hypothesis) -> None:
         """Remove a given hypothesis.
@@ -1629,7 +1656,7 @@ class HypothesisList(object):
         """
         ans = HypothesisList()
         for _, hyp in self._data.items():
-            if hyp.log_prob > threshold:
+            if hyp.tot_score > threshold:
                 ans.add(hyp)  # shallow copy
         return ans
 
@@ -1645,17 +1672,20 @@ class HypothesisList(object):
 
         if length_norm:
             hyps = sorted(
-                hyps, key=lambda h: h[1].log_prob / len(h[1].ys), reverse=True
+                hyps, key=lambda h: h[1].tot_score / len(h[1].ys), reverse=True
             )[:k]
         else:
-            hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
+            hyps = sorted(hyps, key=lambda h: h[1].tot_score, reverse=True)[:k]
 
         ans = HypothesisList(dict(hyps))
         return ans
 
-    def __contains__(self, key: str):
+    def __contains__(self, key: tuple):
         return key in self._data
 
+    def __getitem__(self, key: tuple):
+        return self._data[key]
+
     def __iter__(self):
         return iter(self._data.values())
 
@@ -1694,64 +1724,96 @@ def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
     return ans
 
 
-def _step_worker(log_probs, indexes, B, beam, blank_id):
+def _step_worker(
+    log_probs,
+    indexes,
+    B,
+    beam,
+    blank_id,
+    lm_scale: float = 0,
+    LODR_lm_scale: float = 0,
+    context_graph: Optional[ContextGraph] = None,
+):
     A = list(B)
     B = HypothesisList()
     for h in range(len(A)):
         hyp = A[h]
         for k in range(log_probs.size(0)):
             log_prob, index = log_probs[k], indexes[k]
-            if index == blank_id:
+            new_token = index.item()
+            update_prefix = False
+            new_hyp = hyp.clone()
+            if new_token == blank_id:
                 # Case 0: *a + ε => *a
                 #         *aε + ε => *a
                 # Prefix does not change, update log_prob of blank
-                new_hyp = Hypothesis(
-                    ys=hyp.ys[:],
-                    log_prob_non_blank=torch.tensor(
-                        [float("-inf")], dtype=torch.float32
-                    ),
-                    log_prob_blank=hyp.log_prob + log_prob,
+                new_hyp.log_prob_non_blank = torch.tensor(
+                    [float("-inf")], dtype=torch.float32
                 )
+                new_hyp.log_prob_blank = hyp.log_prob + log_prob
                 B.add(new_hyp)
-            elif len(hyp.ys) > 0 and hyp.ys[-1] == index:
+            elif len(hyp.ys) > 0 and hyp.ys[-1] == new_token:
                 # Case 1: *a + a => *a
                 # Prefix does not change, update log_prob of non_blank
-                new_hyp = Hypothesis(
-                    ys=hyp.ys[:],
-                    log_prob_non_blank=hyp.log_prob_non_blank + log_prob,
-                    log_prob_blank=torch.tensor([float("-inf")], dtype=torch.float32),
+                new_hyp.log_prob_non_blank = hyp.log_prob_non_blank + log_prob
+                new_hyp.log_prob_blank = torch.tensor(
+                    [float("-inf")], dtype=torch.float32
                 )
                 B.add(new_hyp)
 
                 # Case 2: *aε + a => *aa
                 # Prefix changes, update log_prob of blank
-                new_hyp = Hypothesis(
-                    ys=hyp.ys[:] + [index.item()],
-                    log_prob_non_blank=hyp.log_prob_blank + log_prob,
-                    log_prob_blank=torch.tensor([float("-inf")], dtype=torch.float32),
+                new_hyp = hyp.clone()
+                # Caution: DO NOT use append, as clone is shallow copy
+                new_hyp.ys = hyp.ys + [new_token]
+                new_hyp.log_prob_non_blank = hyp.log_prob_blank + log_prob
+                new_hyp.log_prob_blank = torch.tensor(
+                    [float("-inf")], dtype=torch.float32
                 )
-                B.add(new_hyp)
+                update_prefix = True
             else:
                 # Case 3: *a + b => *ab, *aε + b => *ab
                 # Prefix changes, update log_prob of non_blank
-                new_hyp = Hypothesis(
-                    ys=hyp.ys[:] + [index.item()],
-                    log_prob_non_blank=hyp.log_prob + log_prob,
-                    log_prob_blank=torch.tensor([float("-inf")], dtype=torch.float32),
+                # Caution: DO NOT use append, as clone is shallow copy
+                new_hyp.ys = hyp.ys + [new_token]
+                new_hyp.log_prob_non_blank = hyp.log_prob + log_prob
+                new_hyp.log_prob_blank = torch.tensor(
+                    [float("-inf")], dtype=torch.float32
                 )
+                update_prefix = True
+
+            if update_prefix:
+                lm_score = hyp.lm_score
+                if hyp.lm_log_probs is not None:
+                    lm_score += hyp.lm_log_probs[new_token] * lm_scale
+                    new_hyp.lm_log_probs = None
+
+                if context_graph is not None and hyp.context_state is not None:
+                    context_score, new_context_state = context_graph.forward_one_step(
+                        hyp.context_state, new_token
+                    )
+                    lm_score += context_score
+                    new_hyp.context_state = new_context_state
+
+                if hyp.LODR_state is not None:
+                    state_cost = hyp.LODR_state.forward_one_step(new_token)
+                    # calculate the score of the latest token
+                    current_ngram_score = state_cost.lm_score - hyp.LODR_state.lm_score
+                    assert current_ngram_score <= 0.0, (
+                        state_cost.lm_score,
+                        hyp.LODR_state.lm_score,
+                    )
+                    lm_score += LODR_lm_scale * current_ngram_score
+                    new_hyp.LODR_state = state_cost
+
+                new_hyp.lm_score = lm_score
                 B.add(new_hyp)
     B = B.topk(beam)
     return B
 
 
 def _batch_worker(topk_values, topk_indexes, B, encoder_out_lens, beam, blank_id):
-    B.add(
-        Hypothesis(
-            ys=[],
-            log_prob_non_blank=torch.tensor([float("-inf")], dtype=torch.float32),
-            log_prob_blank=torch.zeros(1, dtype=torch.float32),
-        )
-    )
+    B.add(Hypothesis())
     for j in range(encoder_out_lens):
         log_probs, indexes = topk_values[j], topk_indexes[j]
         B = _step_worker(log_probs, indexes, B, beam, blank_id)
@@ -1763,11 +1825,11 @@ def ctc_prefix_beam_search(
     encoder_out_lens: torch.Tensor,
     beam: int = 4,
     blank_id: int = 0,
-    context_graph: Optional[ContextGraph] = None,
     process_pool: Optional[Pool] = None,
     return_nbest: Optional[bool] = False,
 ) -> Union[List[List[int]], List[HypothesisList]]:
     batch_size, num_frames, vocab_size = ctc_output.shape
+    # TODO: using a larger beam for first pass pruning
     topk_values, topk_indexes = ctc_output.topk(beam)  # (B, T, beam)
     topk_values = topk_values.cpu()
@@ -1800,6 +1862,136 @@ def ctc_prefix_beam_search(
     return [hyp.ys for hyp in best_hyps]
 
 
+def ctc_prefix_beam_search_shallow_fussion(
+    ctc_output: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    beam: int = 4,
+    blank_id: int = 0,
+    LODR_lm: Optional[NgramLm] = None,
+    LODR_lm_scale: Optional[float] = 0,
+    LM: Optional[LmScorer] = None,
+    context_graph: Optional[ContextGraph] = None,
+) -> List[List[int]]:
+    batch_size, num_frames, vocab_size = ctc_output.shape
+    # TODO: using a larger beam for first pass pruning
+    topk_values, topk_indexes = ctc_output.topk(beam)  # (B, T, beam)
+    topk_values = topk_values.cpu()
+    topk_indexes = topk_indexes.cpu()
+    encoder_out_lens = encoder_out_lens.tolist()
+    device = ctc_output.device
+
+    lm_scale = 0
+    init_scores = None
+    init_states = None
+
+    if LM is not None:
+        lm_scale = LM.lm_scale
+        sos_id = getattr(LM, "sos_id", 1)
+        # get initial lm score and lm state by scoring the "sos" token
+        sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
+        lens = torch.tensor([1]).to(device)
+        init_scores, init_states = LM.score_token(sos_token, lens)
+        init_scores, init_states = init_scores.cpu(), (
+            init_states[0].cpu(),
+            init_states[1].cpu(),
+        )
+
+    B = [HypothesisList() for _ in range(batch_size)]
+    for i in range(batch_size):
+        B[i].add(
+            Hypothesis(
+                ys=[],
+                log_prob_non_blank=torch.tensor([float("-inf")], dtype=torch.float32),
+                log_prob_blank=torch.zeros(1, dtype=torch.float32),
+                lm_score=torch.zeros(1, dtype=torch.float32),
+                state=init_states,
+                lm_log_probs=None if init_scores is None else init_scores.reshape(-1),
+                LODR_state=None if LODR_lm is None else NgramLmStateCost(LODR_lm),
+                context_state=None if context_graph is None else context_graph.root,
+            )
+        )
+    for j in range(num_frames):
+        for i in range(batch_size):
+            if j < encoder_out_lens[i]:
+                log_probs, indexes = topk_values[i][j], topk_indexes[i][j]
+                B[i] = _step_worker(
+                    log_probs,
+                    indexes,
+                    B[i],
+                    beam,
+                    blank_id,
+                    lm_scale=lm_scale,
+                    LODR_lm_scale=LODR_lm_scale,
+                    context_graph=context_graph,
+                )
+        if LM is None:
+            continue
+        # update lm_score
+        token_list = []  # a list of list
+        hs = []
+        cs = []
+        indexes = []  # (batch_idx, key)
+        for batch_idx, hyps in enumerate(B):
+            for hyp in hyps:
+                if hyp.lm_log_probs is None:
+                    if LM.lm_type == "rnn":
+                        token_list.append([hyp.ys[-1]])
+                        # store the LSTM states
+                        hs.append(hyp.state[0])
+                        cs.append(hyp.state[1])
+                    else:
+                        # for transformer LM
+                        token_list.append([sos_id] + hyp.ys[:])
+                    indexes.append((batch_idx, hyp.key))
+        if len(token_list) != 0:
+            x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device)
+            if LM.lm_type == "rnn":
+                tokens_to_score = (
+                    torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
+                )
+                hs = torch.cat(hs, dim=1).to(device)
+                cs = torch.cat(cs, dim=1).to(device)
+                state = (hs, cs)
+            else:
+                # for transformer LM
+                tokens_list = [torch.tensor(tokens) for tokens in token_list]
+                tokens_to_score = (
+                    torch.nn.utils.rnn.pad_sequence(
+                        tokens_list, batch_first=True, padding_value=0.0
+                    )
+                    .to(device)
+                    .to(torch.int64)
+                )
+                state = None
+
+            scores, lm_states = LM.score_token(tokens_to_score, x_lens, state)
+            scores, lm_states = scores.cpu(), (lm_states[0].cpu(), lm_states[1].cpu())
+            assert scores.size(0) == len(indexes), (scores.size(0), len(indexes))
+            for i in range(scores.size(0)):
+                batch_idx, key = indexes[i]
+                B[batch_idx][key].lm_log_probs = scores[i]
+                if LM.lm_type == "rnn":
+                    state = (
+                        lm_states[0][:, i, :].unsqueeze(1),
+                        lm_states[1][:, i, :].unsqueeze(1),
+                    )
+                    B[batch_idx][key].state = state
+
+    # finalize context_state, if the matched contexts do not reach final state
+    # we need to add the score on the corresponding backoff arc
+    if context_graph is not None:
+        for hyps in B:
+            for hyp in hyps:
+                context_score, new_context_state = context_graph.finalize(
+                    hyp.context_state
+                )
+                hyp.lm_score += context_score
+                hyp.context_state = new_context_state
+
+    best_hyps = [b.get_most_probable() for b in B]
+    return [hyp.ys for hyp in best_hyps]
+
+
 def ctc_prefix_beam_search_attention_decoder_rescoring(
     ctc_output: torch.Tensor,
     attention_decoder: torch.nn.Module,
diff --git a/icefall/utils.py b/icefall/utils.py
index 1dbb954de..1f72addf2 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -19,8 +19,10 @@
 
 import argparse
 import collections
+import json
 import logging
 import os
+import pathlib
 import re
 import subprocess
 from collections import defaultdict
@@ -178,6 +180,15 @@ class AttributeDict(dict):
             return
         raise AttributeError(f"No such attribute '{key}'")
 
+    def __str__(self, indent: int = 2):
+        tmp = {}
+        for k, v in self.items():
+            # PosixPath is not JSON serializable
+            if isinstance(v, pathlib.Path) or isinstance(v, torch.device):
+                v = str(v)
+            tmp[k] = v
+        return json.dumps(tmp, indent=indent, sort_keys=True)
+
 
 def encode_supervisions(
     supervisions: dict,
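
For readers who want to experiment with the update rule that `_step_worker` implements (Cases 0-3 above), here is a minimal, self-contained sketch in plain Python. It is not the icefall implementation: `prefix_beam_search` and `log_add` are illustrative names, it scores every token instead of only the top-k per frame, it keeps the per-prefix (blank, non-blank) log-probabilities in a dict rather than in `Hypothesis`/`HypothesisList` objects, and it leaves LM fusion out.

import math
from collections import defaultdict
from typing import Dict, List, Tuple

NEG_INF = float("-inf")


def log_add(a: float, b: float) -> float:
    """Numerically stable log(exp(a) + exp(b))."""
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    m = max(a, b)
    return m + math.log1p(math.exp(min(a, b) - m))


def prefix_beam_search(log_probs: List[List[float]], beam: int, blank_id: int = 0) -> List[int]:
    # Each prefix maps to (log_prob_blank, log_prob_non_blank), i.e. the
    # probability of emitting that prefix and ending in a blank / non-blank
    # frame, exactly like the fields of the Hypothesis dataclass.
    beams: Dict[Tuple[int, ...], Tuple[float, float]] = {(): (0.0, NEG_INF)}
    for frame in log_probs:
        new_beams: Dict[Tuple[int, ...], Tuple[float, float]] = defaultdict(
            lambda: (NEG_INF, NEG_INF)
        )
        for prefix, (p_b, p_nb) in beams.items():
            for token, log_p in enumerate(frame):
                if token == blank_id:
                    # Case 0: prefix unchanged, now ends in blank
                    nb_b, nb_nb = new_beams[prefix]
                    new_beams[prefix] = (log_add(nb_b, log_add(p_b, p_nb) + log_p), nb_nb)
                elif prefix and prefix[-1] == token:
                    # Case 1: repeated token merges into the same prefix
                    nb_b, nb_nb = new_beams[prefix]
                    new_beams[prefix] = (nb_b, log_add(nb_nb, p_nb + log_p))
                    # Case 2: ...blank + token extends the prefix
                    ext = prefix + (token,)
                    nb_b, nb_nb = new_beams[ext]
                    new_beams[ext] = (nb_b, log_add(nb_nb, p_b + log_p))
                else:
                    # Case 3: any other token extends the prefix
                    ext = prefix + (token,)
                    nb_b, nb_nb = new_beams[ext]
                    new_beams[ext] = (nb_b, log_add(nb_nb, log_add(p_b, p_nb) + log_p))
        # keep the `beam` best prefixes, mirroring HypothesisList.topk
        beams = dict(
            sorted(
                new_beams.items(),
                key=lambda kv: log_add(kv[1][0], kv[1][1]),
                reverse=True,
            )[:beam]
        )
    best = max(beams.items(), key=lambda kv: log_add(kv[1][0], kv[1][1]))
    return list(best[0])


if __name__ == "__main__":
    # 3 frames over a 3-token vocabulary (0 = blank); values are log-probs.
    frames = [
        [math.log(0.6), math.log(0.3), math.log(0.1)],
        [math.log(0.2), math.log(0.7), math.log(0.1)],
        [math.log(0.5), math.log(0.4), math.log(0.1)],
    ]
    print(prefix_beam_search(frames, beam=4))  # prints: [1]

The pruning step plays the role of `HypothesisList.topk`, and the final ranking corresponds to `Hypothesis.log_prob` (the log-sum of the blank and non-blank scores).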
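
A rough usage sketch of the new `ctc_prefix_beam_search` entry point, mirroring how `decode_one_batch` calls it in the patch. The random tensor stands in for real zipformer CTC log-probs, the shapes follow the comments in the diff, and it assumes an environment where `icefall.decode` is importable.

import torch

from icefall.decode import ctc_prefix_beam_search

if __name__ == "__main__":
    batch_size, num_frames, vocab_size = 2, 20, 500
    # stand-in for model.ctc_output(encoder_out): log-probs over BPE tokens
    ctc_output = torch.randn(batch_size, num_frames, vocab_size).log_softmax(dim=-1)
    encoder_out_lens = torch.tensor([20, 17])

    token_ids = ctc_prefix_beam_search(
        ctc_output=ctc_output,
        encoder_out_lens=encoder_out_lens,
        beam=4,
        blank_id=0,
    )
    # token_ids is a list of `batch_size` token-id lists; the recipe turns
    # them into text with bpe_model.decode(token_ids).
    print(len(token_ids), token_ids[0][:10])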
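
The piece that makes shallow fusion work is the scoring change: `HypothesisList.get_most_probable`, `topk`, and `filter` now rank by `Hypothesis.tot_score` (acoustic `log_prob` plus the accumulated `lm_score`) instead of the acoustic score alone. A tiny illustration with made-up numbers, again assuming `icefall.decode` is importable:

import torch

from icefall.decode import Hypothesis

# Acoustic-only hypothesis: lm_score defaults to zero.
hyp_a = Hypothesis(ys=[7, 3], log_prob_blank=torch.tensor([-2.0]))
# Hypothesis whose lm_score was accumulated by _step_worker during decoding.
hyp_b = Hypothesis(
    ys=[7, 4], log_prob_blank=torch.tensor([-1.8]), lm_score=torch.tensor([-0.5])
)

# log_prob = logaddexp(log_prob_non_blank, log_prob_blank); the default
# log_prob_non_blank is -inf, so log_prob equals log_prob_blank here.
print(hyp_a.tot_score)  # -2.0 + 0.0
print(hyp_b.tot_score)  # -1.8 + (-0.5) = -2.3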
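
Finally, the small `icefall/utils.py` change is what lets `logging.info(params)` in the recipe print cleanly: `AttributeDict.__str__` now emits sorted JSON, stringifying `Path` and `torch.device` values first. A quick illustration (assuming icefall is importable):

import pathlib

import torch

from icefall.utils import AttributeDict

params = AttributeDict(
    {
        "beam": 4,
        "exp_dir": pathlib.Path("zipformer/exp"),
        "device": torch.device("cpu"),
    }
)
print(params)
# Expected output (keys sorted, indent=2):
# {
#   "beam": 4,
#   "device": "cpu",
#   "exp_dir": "zipformer/exp"
# }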