From 0c096a9ab4dec3f29639f6a001907e01dc823b78 Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 26 Sep 2024 15:22:29 +0800 Subject: [PATCH] add ctc prefix beam search --- egs/gigaspeech/ASR/zipformer/ctc_decode.py | 49 ++- icefall/decode.py | 371 ++++++++++++++++++++- 2 files changed, 409 insertions(+), 11 deletions(-) diff --git a/egs/gigaspeech/ASR/zipformer/ctc_decode.py b/egs/gigaspeech/ASR/zipformer/ctc_decode.py index 8f70d256b..a3405d4b9 100755 --- a/egs/gigaspeech/ASR/zipformer/ctc_decode.py +++ b/egs/gigaspeech/ASR/zipformer/ctc_decode.py @@ -124,7 +124,7 @@ from asr_datamodule import GigaSpeechAsrDataModule from gigaspeech_scoring import asr_text_post_processing from lhotse import set_caching_enabled -from train import add_model_arguments, get_model, get_params +from train_cr_aed import add_model_arguments, get_model, get_params from icefall.checkpoint import ( average_checkpoints, @@ -134,6 +134,7 @@ from icefall.checkpoint import ( ) from icefall.decode import ( ctc_greedy_search, + ctc_prefix_beam_search, get_lattice, nbest_decoding, nbest_oracle, @@ -327,6 +328,17 @@ def get_decoding_params() -> AttributeDict: return params +def post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)).split() + new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + def decode_one_batch( params: AttributeDict, model: nn.Module, @@ -380,10 +392,7 @@ def decode_one_batch( Return the decoding result. See above description for the format of the returned dict. Note: If it decodes to nothing, then return None. """ - if HLG is not None: - device = HLG.device - else: - device = H.device + device = params.device feature = batch["inputs"] assert feature.ndim == 3 feature = feature.to(device) @@ -414,6 +423,18 @@ def decode_one_batch( key = "ctc-greedy-search" return {key: hyps} + if params.decoding_method == "prefix-beam-search": + token_ids = ctc_prefix_beam_search( + ctc_output=ctc_output, encoder_out_lens=encoder_out_lens, beam=8 + ) + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... 
] + hyps = [s.split() for s in hyps] + key = "prefix-beam-search" + return {key: hyps} + supervision_segments = torch.stack( ( supervisions["sequence_idx"], @@ -738,6 +759,7 @@ def main(): assert params.decoding_method in ( "ctc-greedy-search", + "prefix-beam-search", "ctc-decoding", "1best", "nbest", @@ -773,6 +795,7 @@ def main(): device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", 0) + params.device = device logging.info(f"Device: {device}") logging.info(params) @@ -790,14 +813,20 @@ def main(): if params.decoding_method in [ "ctc-greedy-search", "ctc-decoding", + "prefix-beam-search", "attention-decoder-rescoring-no-ngram", ]: HLG = None - H = k2.ctc_topo( - max_token=max_token_id, - modified=False, - device=device, - ) + H = None + if params.decoding_method in [ + "ctc-decoding", + "attention-decoder-rescoring-no-ngram", + ]: + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) bpe_model = spm.SentencePieceProcessor() bpe_model.load(str(params.lang_dir / "bpe.model")) else: diff --git a/icefall/decode.py b/icefall/decode.py index dd3af1e99..addbc3ff7 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -15,11 +15,18 @@ # limitations under the License. import logging -from typing import Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple, Union import k2 import torch +from multiprocessing.pool import Pool + +from icefall.context_graph import ContextGraph, ContextState +from icefall.ngram_lm import NgramLm, NgramLmStateCost +from icefall.lm_wrapper import LmScorer + from icefall.utils import add_eos, add_sos, get_texts DEFAULT_LM_SCALE = [ @@ -1497,3 +1504,365 @@ def ctc_greedy_search( hyps = [h[h != blank_id].tolist() for h in hyps] return hyps + + +@dataclass +class Hypothesis: + # The predicted tokens so far. + # Newly predicted tokens are appended to `ys`. + ys: List[int] + + # The log prob of ys. + # It contains only one entry. + log_prob_blank: torch.Tensor + + log_prob_non_blank: torch.Tensor + + # timestamp[i] is the frame index after subsampling + # on which ys[i] is decoded + timestamp: List[int] = field(default_factory=list) + + # the lm score for next token given the current ys + lm_score: Optional[torch.Tensor] = None + + # the RNNLM states (h and c in LSTM) + state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + + # N-gram LM state + state_cost: Optional[NgramLmStateCost] = None + + # Context graph state + context_state: Optional[ContextState] = None + + @property + def log_prob(self) -> torch.Tensor: + return torch.logaddexp(self.log_prob_non_blank, self.log_prob_blank) + + @property + def key(self) -> tuple: + """Return a tuple representation of self.ys""" + return tuple(self.ys) + + +class HypothesisList(object): + def __init__(self, data: Optional[Dict[tuple, Hypothesis]] = None) -> None: + """ + Args: + data: + A dict of Hypotheses. Its key is its `value.key`. + """ + if data is None: + self._data = {} + else: + self._data = data + + @property + def data(self) -> Dict[tuple, Hypothesis]: + return self._data + + def add(self, hyp: Hypothesis) -> None: + """Add a Hypothesis to `self`. + + If `hyp` already exists in `self`, its probability is updated using + `log-sum-exp` with the existed one. + + Args: + hyp: + The hypothesis to be added. 
+        """
+        key = hyp.key
+        if key in self:
+            old_hyp = self._data[key]  # shallow copy
+            torch.logaddexp(
+                old_hyp.log_prob_blank, hyp.log_prob_blank, out=old_hyp.log_prob_blank
+            )
+            torch.logaddexp(
+                old_hyp.log_prob_non_blank,
+                hyp.log_prob_non_blank,
+                out=old_hyp.log_prob_non_blank,
+            )
+        else:
+            self._data[key] = hyp
+
+    def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
+        """Get the most probable hypothesis, i.e., the one with
+        the largest `log_prob`.
+
+        Args:
+          length_norm:
+            If True, the `log_prob` of a hypothesis is normalized by the
+            number of tokens in it.
+        Returns:
+          Return the hypothesis that has the largest `log_prob`.
+        """
+        if length_norm:
+            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
+        else:
+            return max(self._data.values(), key=lambda hyp: hyp.log_prob)
+
+    def remove(self, hyp: Hypothesis) -> None:
+        """Remove a given hypothesis.
+
+        Caution:
+          `self` is modified **in-place**.
+
+        Args:
+          hyp:
+            The hypothesis to be removed from `self`.
+            Note: It must be contained in `self`. Otherwise,
+            an exception is raised.
+        """
+        key = hyp.key
+        assert key in self, f"{key} does not exist"
+        del self._data[key]
+
+    def filter(self, threshold: torch.Tensor) -> "HypothesisList":
+        """Remove all Hypotheses whose log_prob is less than threshold.
+
+        Caution:
+          `self` is not modified. Instead, a new HypothesisList is returned.
+
+        Returns:
+          Return a new HypothesisList containing all hypotheses from `self`
+          with `log_prob` being greater than the given `threshold`.
+        """
+        ans = HypothesisList()
+        for _, hyp in self._data.items():
+            if hyp.log_prob > threshold:
+                ans.add(hyp)  # shallow copy
+        return ans
+
+    def topk(self, k: int, length_norm: bool = False) -> "HypothesisList":
+        """Return the top-k hypotheses.
+
+        Args:
+          k:
+            The number of hypotheses to keep.
+          length_norm:
+            If True, the `log_prob` of a hypothesis is normalized by the
+            number of tokens in it.
+        """
+        hyps = list(self._data.items())
+
+        if length_norm:
+            hyps = sorted(
+                hyps, key=lambda h: h[1].log_prob / len(h[1].ys), reverse=True
+            )[:k]
+        else:
+            hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
+
+        ans = HypothesisList(dict(hyps))
+        return ans
+
+    def __contains__(self, key: tuple):
+        return key in self._data
+
+    def __iter__(self):
+        return iter(self._data.values())
+
+    def __len__(self) -> int:
+        return len(self._data)
+
+    def __str__(self) -> str:
+        s = [str(hyp.key) for hyp in self]
+        return ", ".join(s)
+
+
+def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
+    """Return a ragged shape with axes [utt][num_hyps].
+
+    Args:
+      hyps:
+        len(hyps) == batch_size. It contains the current hypothesis for
+        each utterance in the batch.
+    Returns:
+      Return a ragged shape with 2 axes [utt][num_hyps]. Note that
+      the shape is on CPU.
+    """
+    num_hyps = [len(h) for h in hyps]
+
+    # torch.cumsum() is inclusive sum, so we put a 0 at the beginning
+    # to get exclusive sum later.
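+    # E.g., if the two utterances in a batch have 2 and 3 hypotheses,
+    # num_hyps = [2, 3] becomes [0, 2, 3] after the insert below, and the
+    # cumulative sum gives row_splits = [0, 2, 5], i.e. hyps 0-1 belong to
+    # utterance 0 and hyps 2-4 to utterance 1.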
+ num_hyps.insert(0, 0) + + num_hyps = torch.tensor(num_hyps) + row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) + ans = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=row_splits[-1].item() + ) + return ans + + +def _step_worker(log_probs, indexes, B, beam, blank_id): + A = list(B) + B = HypothesisList() + for h in range(len(A)): + hyp = A[h] + for k in range(log_probs.size(0)): + log_prob, index = log_probs[k], indexes[k] + if index == blank_id: + # Case 0: *a + ε => *a + # *aε + ε => *a + # Prefix does not change, update log_prob of blank + new_hyp = Hypothesis( + ys=hyp.ys[:], + log_prob_non_blank=torch.tensor( + [float("-inf")], dtype=torch.float32 + ), + log_prob_blank=hyp.log_prob + log_prob, + ) + B.add(new_hyp) + elif len(hyp.ys) > 0 and hyp.ys[-1] == index: + # Case 1: *a + a => *a + # Prefix does not change, update log_prob of non_blank + new_hyp = Hypothesis( + ys=hyp.ys[:], + log_prob_non_blank=hyp.log_prob_non_blank + log_prob, + log_prob_blank=torch.tensor([float("-inf")], dtype=torch.float32), + ) + B.add(new_hyp) + + # Case 2: *aε + a => *aa + # Prefix changes, update log_prob of blank + new_hyp = Hypothesis( + ys=hyp.ys[:] + [index.item()], + log_prob_non_blank=hyp.log_prob_blank + log_prob, + log_prob_blank=torch.tensor([float("-inf")], dtype=torch.float32), + ) + B.add(new_hyp) + else: + # Case 3: *a + b => *ab, *aε + b => *ab + # Prefix changes, update log_prob of non_blank + new_hyp = Hypothesis( + ys=hyp.ys[:] + [index.item()], + log_prob_non_blank=hyp.log_prob + log_prob, + log_prob_blank=torch.tensor([float("-inf")], dtype=torch.float32), + ) + B.add(new_hyp) + B = B.topk(beam) + return B + + +def _batch_worker(topk_values, topk_indexes, B, encoder_out_lens, beam, blank_id): + B.add( + Hypothesis( + ys=[], + log_prob_non_blank=torch.tensor([float("-inf")], dtype=torch.float32), + log_prob_blank=torch.zeros(1, dtype=torch.float32), + ) + ) + for j in range(encoder_out_lens): + log_probs, indexes = topk_values[j], topk_indexes[j] + B = _step_worker(log_probs, indexes, B, beam, blank_id) + return B + + +def ctc_prefix_beam_search( + ctc_output: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: int = 4, + blank_id: int = 0, + context_graph: Optional[ContextGraph] = None, + process_pool: Optional[Pool] = None, + return_nbest: Optional[bool] = False, +) -> Union[List[List[int]], List[HypothesisList]]: + batch_size, num_frames, vocab_size = ctc_output.shape + # TODO: using a larger beam for first pass pruning + topk_values, topk_indexes = ctc_output.topk(beam) # (B, T, beam) + topk_values = topk_values.cpu() + topk_indexes = topk_indexes.cpu() + + B = [HypothesisList() for _ in range(batch_size)] + + pool = Pool() if process_pool is None else process_pool + arguments = [] + for i in range(batch_size): + arguments.append( + ( + topk_values[i], + topk_indexes[i], + B[i], + encoder_out_lens[i].item(), + beam, + blank_id, + ) + ) + async_results = pool.starmap_async(_batch_worker, arguments) + B = list(async_results.get()) + if process_pool is None: + pool.close() + pool.join() + if return_nbest: + return B + else: + best_hyps = [b.get_most_probable() for b in B] + return [hyp.ys for hyp in best_hyps] + + +def ctc_prefix_beam_search_attention_decoder_rescoring( + ctc_output: torch.Tensor, + attention_decoder: torch.nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: int = 8, + blank_id: int = 0, + attention_scale: Optional[float] = None, +): + # List[HypothesisList] + nbest = ctc_prefix_beam_search( + 
ctc_output=ctc_output, + encoder_out_lens=encoder_out_lens, + beam=beam, + blank_id=blank_id, + return_nbest=True, + ) + + device = ctc_output.device + + hyp_shape = get_hyps_shape(nbest).to(device) + hyp_to_utt_map = hyp_shape.row_ids(1).to(torch.long) + # the shape of encoder_out is (N, T, C), so we use axis=0 here + expanded_encoder_out = encoder_out.index_select(0, hyp_to_utt_map) + expanded_encoder_out_lens = encoder_out_lens.index_select(0, hyp_to_utt_map) + + nbest = [list(x) for x in nbest] + token_ids = [] + scores = [] + for hyps in nbest: + for hyp in hyps: + token_ids.append(hyp.ys) + scores.append(hyp.log_prob.reshape(1)) + scores = torch.cat(scores).to(device) + + nll = attention_decoder.nll( + encoder_out=expanded_encoder_out, + encoder_out_lens=expanded_encoder_out_lens, + token_ids=token_ids, + ) + assert nll.ndim == 2 + assert nll.shape[0] == len(token_ids) + + attention_scores = -nll.sum(dim=1) + + if attention_scale is None: + attention_scale_list = [0.01, 0.05, 0.08] + attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] + attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0] + attention_scale_list += [5.0, 6.0, 7.0, 8.0, 9.0] + else: + attention_scale_list = [attention_scale] + + ans = dict() + + start_indexes = hyp_shape.row_splits(1)[0:-1] + for a_scale in attention_scale_list: + tot_scores = scores + a_scale * attention_scores + ragged_tot_scores = k2.RaggedTensor(hyp_shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + max_indexes = max_indexes - start_indexes + max_indexes = max_indexes.cpu() + best_path = [nbest[i][max_indexes[i]].ys for i in range(len(max_indexes))] + key = f"attention_scale_{a_scale}" + ans[key] = best_path + return ans
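
A minimal usage sketch, assuming `ctc_output` holds the (N, T, vocab_size)
log-probabilities from the CTC head, `encoder_out_lens` the number of valid
frames per utterance, and `sp` the recipe's SentencePiece model (all three
names are placeholders, not part of this patch):

    from icefall.decode import ctc_prefix_beam_search

    # token_ids: List[List[int]], the best prefix per utterance
    token_ids = ctc_prefix_beam_search(
        ctc_output=ctc_output,  # (N, T, vocab_size) log-probs
        encoder_out_lens=encoder_out_lens,  # (N,) valid frame counts
        beam=8,
        blank_id=0,
    )
    hyps = [sp.decode(ids).split() for ids in token_ids]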