diff --git a/icefall/decode.py b/icefall/decode.py
index d23ce2ebb..777f9e3e8 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -1513,10 +1513,12 @@ class Hypothesis:
     # Newly predicted tokens are appended to `ys`.
     ys: List[int] = field(default_factory=list)

-    # The log prob of ys.
+    # The log prob of ys that ends with the blank token.
     # It contains only one entry.
     log_prob_blank: torch.Tensor = torch.zeros(1, dtype=torch.float32)

+    # The log prob of ys that ends with a non-blank token.
+    # It contains only one entry.
     log_prob_non_blank: torch.Tensor = torch.tensor(
         [float("-inf")], dtype=torch.float32
     )
@@ -1526,16 +1528,18 @@ class Hypothesis:
     timestamp: List[int] = field(default_factory=list)

     # The lm score of ys
+    # May contain the external LM score (including the LODR score) and the contextual biasing score
     # It contains only one entry
     lm_score: torch.Tensor = torch.zeros(1, dtype=torch.float32)

     # the lm log_probs for next token given the history ys
+    # The number of elements should be equal to the vocabulary size.
     lm_log_probs: Optional[torch.Tensor] = None

     # the RNNLM states (h and c in LSTM)
     state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

-    # N-gram LM state
+    # LODR (N-gram LM) state
     LODR_state: Optional[NgramLmStateCost] = None

     # N-gram LM state
@@ -1544,10 +1548,12 @@ class Hypothesis:
     # Context graph state
     context_state: Optional[ContextState] = None

+    # The total score of the current path: acoustic score plus external LM score.
     @property
     def tot_score(self) -> torch.Tensor:
         return self.log_prob + self.lm_score

+    # Only the probability from the model output (i.e., the external LM score is not included).
     @property
     def log_prob(self) -> torch.Tensor:
         return torch.logaddexp(self.log_prob_non_blank, self.log_prob_blank)
@@ -1614,14 +1620,14 @@ class HypothesisList(object):
     def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
         """Get the most probable hypothesis, i.e., the one with
-        the largest `log_prob`.
+        the largest `tot_score`.

         Args:
           length_norm:
-            If True, the `log_prob` of a hypothesis is normalized by the
+            If True, the `tot_score` of a hypothesis is normalized by the
             number of tokens in it.
         Returns:
-          Return the hypothesis that has the largest `log_prob`.
+          Return the hypothesis that has the largest `tot_score`.
         """
         if length_norm:
             return max(self._data.values(), key=lambda hyp: hyp.tot_score / len(hyp.ys))
@@ -1645,14 +1651,14 @@
         del self._data[key]

     def filter(self, threshold: torch.Tensor) -> "HypothesisList":
-        """Remove all Hypotheses whose log_prob is less than threshold.
+        """Remove all Hypotheses whose tot_score is less than threshold.

         Caution:
           `self` is not modified. Instead, a new HypothesisList is returned.

         Returns:
           Return a new HypothesisList containing all hypotheses from `self`
-          with `log_prob` being greater than the given `threshold`.
+          with `tot_score` being greater than the given `threshold`.
         """
         ans = HypothesisList()
         for _, hyp in self._data.items():
@@ -1665,7 +1671,7 @@

         Args:
           length_norm:
-            If True, the `log_prob` of a hypothesis is normalized by the
+            If True, the `tot_score` of a hypothesis is normalized by the
             number of tokens in it.
         """
         hyps = list(self._data.items())
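The split into `log_prob_blank` and `log_prob_non_blank` above is the usual CTC prefix bookkeeping: a prefix's acoustic probability is the sum of the probabilities of paths ending in blank and paths ending in a non-blank. Below is a minimal sketch (not part of the patch; the values are made up) of how the two entries combine in the `log_prob` and `tot_score` properties:

    import torch

    # One prefix keeps two log probs: paths ending in blank vs. non-blank.
    log_prob_blank = torch.tensor([-1.2])      # log P(prefix, ends in blank)
    log_prob_non_blank = torch.tensor([-0.7])  # log P(prefix, ends in non-blank)

    # Acoustic log prob of the prefix: log-sum-exp of the two endings,
    # which is what the `log_prob` property computes.
    log_prob = torch.logaddexp(log_prob_non_blank, log_prob_blank)

    # `tot_score` then adds any external LM / biasing score on top.
    lm_score = torch.tensor([-0.3])
    tot_score = log_prob + lm_score
    print(log_prob.item(), tot_score.item())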
""" hyps = list(self._data.items()) @@ -1725,15 +1731,39 @@ def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: def _step_worker( - log_probs, - indexes, - B, - beam, - blank_id, + log_probs: torch.Tensor, + indexes: torch.Tensor, + B: HypothesisList, + beam: int = 4, + blank_id: int = 0, lm_scale: float = 0, LODR_lm_scale: float = 0, context_graph: Optional[ContextGraph] = None, -): +) -> HypothesisList: + """The worker to decode one step. + + Args: + log_probs: + topk log_probs of current step (i.e. the kept tokens of first pass pruning), + the shape is (beam,) + topk_indexes: + The indexes of the topk_values above, the shape is (beam,) + B: + An instance of HypothesisList containing the kept hypothesis. + beam: + The number of hypothesis to be kept at each step. + blank_id: + The id of blank in the vocabulary. + lm_scale: + The scale of nn lm. + LODR_lm_scale: + The scale of the LODR_lm + context_graph: + A ContextGraph instance containing contextual phrases. + + Return: + Returns the updated HypothesisList. + """ A = list(B) B = HypothesisList() for h in range(len(A)): @@ -1812,7 +1842,34 @@ def _step_worker( return B -def _batch_worker(topk_values, topk_indexes, B, encoder_out_lens, beam, blank_id): +def _sequence_worker( + topk_values: torch.Tensor, + topk_indexes: torch.Tensor, + B: HypothesisList, + encoder_out_lens: torch.Tensor, + beam: int = 4, + blank_id: int = 0, +) -> HypothesisList: + """The worker to decode one sequence. + + Args: + topk_values: + topk log_probs of model output (i.e. the kept tokens of first pass pruning), + the shape is (T, beam) + topk_indexes: + The indexes of the topk_values above, the shape is (T, beam) + B: + An instance of HypothesisList containing the kept hypothesis. + encoder_out_lens: + The lengths (frames) of sequences after subsampling, the shape is (B,) + beam: + The number of hypothesis to be kept at each step. + blank_id: + The id of blank in the vocabulary. + + Return: + Returns the updated HypothesisList. + """ B.add(Hypothesis()) for j in range(encoder_out_lens): log_probs, indexes = topk_values[j], topk_indexes[j] @@ -1828,6 +1885,24 @@ def ctc_prefix_beam_search( process_pool: Optional[Pool] = None, return_nbest: Optional[bool] = False, ) -> Union[List[List[int]], List[HypothesisList]]: + """Implement prefix search decoding in "Connectionist Temporal Classification: + Labelling Unsegmented Sequence Data with Recurrent Neural Networks". + + Args: + ctc_output: + The output of ctc head (log probability), the shape is (B, T, V) + encoder_out_lens: + The lengths (frames) of sequences after subsampling, the shape is (B,) + beam: + The number of hypothesis to be kept at each step. + blank_id: + The id of blank in the vocabulary. + process_pool: + The process pool for parallel decoding, if not provided, it will use all + you cpu cores by default. + return_nbest: + If true, return a list of HypothesisList, return a list of list of decoded token ids otherwise. 
@@ -1850,7 +1925,7 @@
                 blank_id,
             )
         )
-    async_results = pool.starmap_async(_batch_worker, arguments)
+    async_results = pool.starmap_async(_sequence_worker, arguments)
     B = list(async_results.get())
     if process_pool is None:
         pool.close()
@@ -1872,6 +1947,32 @@ def ctc_prefix_beam_search_shallow_fussion(
     LM: Optional[LmScorer] = None,
     context_graph: Optional[ContextGraph] = None,
 ) -> List[List[int]]:
+    """Implement prefix search decoding in "Connectionist Temporal Classification:
+    Labelling Unsegmented Sequence Data with Recurrent Neural Networks" and add
+    neural language model shallow fusion; it also supports contextual
+    biasing with a given grammar.
+
+    Args:
+      ctc_output:
+        The output of the CTC head (log probabilities); the shape is (B, T, V).
+      encoder_out_lens:
+        The lengths (in frames) of the sequences after subsampling; the shape
+        is (B,).
+      beam:
+        The number of hypotheses to be kept at each step.
+      blank_id:
+        The id of blank in the vocabulary.
+      LODR_lm:
+        A low-order n-gram LM whose score will be subtracted during shallow
+        fusion.
+      LODR_lm_scale:
+        The scale of the LODR LM.
+      LM:
+        A neural network LM, e.g., an RNNLM or a transformer LM.
+      context_graph:
+        A ContextGraph instance containing contextual phrases.
+
+    Return:
+      Returns a list of lists of decoded token ids.
+    """
     batch_size, num_frames, vocab_size = ctc_output.shape
     # TODO: using a larger beam for first pass pruning
     topk_values, topk_indexes = ctc_output.topk(beam)  # (B, T, beam)
@@ -1926,14 +2027,14 @@
             )
         if LM is None:
             continue
-        # update lm_score
+        # update lm_log_probs
        token_list = []  # a list of list
        hs = []
        cs = []
        indexes = []  # (batch_idx, key)
        for batch_idx, hyps in enumerate(B):
            for hyp in hyps:
-                if hyp.lm_log_probs is None:
+                if hyp.lm_log_probs is None:  # those hyps whose prefix has changed
                    if LM.lm_type == "rnn":
                        token_list.append([hyp.ys[-1]])
                        # store the LSTM states
@@ -2000,7 +2101,32 @@
     beam: int = 8,
     blank_id: int = 0,
     attention_scale: Optional[float] = None,
+    process_pool: Optional[Pool] = None,
 ):
+    """Implement prefix search decoding in "Connectionist Temporal Classification:
+    Labelling Unsegmented Sequence Data with Recurrent Neural Networks" and add
+    attention decoder rescoring.
+
+    Args:
+      ctc_output:
+        The output of the CTC head (log probabilities); the shape is (B, T, V).
+      attention_decoder:
+        The attention decoder.
+      encoder_out:
+        The output of the encoder; the shape is (B, T, D).
+      encoder_out_lens:
+        The lengths (in frames) of the sequences after subsampling; the shape
+        is (B,).
+      beam:
+        The number of hypotheses to be kept at each step.
+      blank_id:
+        The id of blank in the vocabulary.
+      attention_scale:
+        The scale of the attention decoder score; if not provided, a default
+        list of scales will be searched (see the code below).
+      process_pool:
+        The process pool for parallel decoding; if not provided, all your CPU
+        cores will be used by default.
+    """
     # List[HypothesisList]
     nbest = ctc_prefix_beam_search(
         ctc_output=ctc_output,
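For reference, a hedged usage sketch of the entry point added in this patch, based only on the signatures visible in the diff (the tensors are random stand-ins for real model output, and `icefall` must be importable). Since the function may spin up a multiprocessing pool, the call is kept under a main guard:

    import torch
    from icefall.decode import ctc_prefix_beam_search

    if __name__ == "__main__":
        batch_size, num_frames, vocab_size = 2, 20, 500
        ctc_output = torch.randn(batch_size, num_frames, vocab_size).log_softmax(-1)
        encoder_out_lens = torch.tensor([20, 17])

        # With return_nbest left at False this should return one list of
        # decoded token ids per utterance; beam/blank_id are illustrative.
        hyps = ctc_prefix_beam_search(
            ctc_output=ctc_output,
            encoder_out_lens=encoder_out_lens,
            beam=4,
            blank_id=0,
        )
        print(hyps)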