From adb54aea91abe211b19ec75eeb422b15a3867405 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 15 Feb 2022 12:33:53 +0800 Subject: [PATCH] Add backoff arcs to the start state to handle OOV word. --- .../ASR/transducer_stateless/beam_search.py | 343 ++++++++---------- .../ASR/transducer_stateless/decode.py | 40 +- .../transducer_stateless/shallow_fusion.py | 90 +++-- .../ASR/transducer_stateless/utils.py | 219 +++++++++++ 4 files changed, 459 insertions(+), 233 deletions(-) create mode 100644 egs/librispeech/ASR/transducer_stateless/utils.py diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py index ecddbf5a9..088885f73 100644 --- a/egs/librispeech/ASR/transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py @@ -14,13 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass from typing import Dict, List, Optional import k2 import torch from model import Transducer from shallow_fusion import shallow_fusion +from utils import Hypothesis, HypothesisList def greedy_search( @@ -103,153 +103,6 @@ def greedy_search( return hyp -@dataclass -class Hypothesis: - # The predicted tokens so far. - # Newly predicted tokens are appended to `ys`. - ys: List[int] - - # The log prob of ys. - # It contains only one entry. - log_prob: torch.Tensor - - # Used for shallow fusion - # The key of the dict is a state index into LG - # while the corresponding value is the LM score - # reaching this state. - # Note: The value tensor contains only a single entry - ngram_state_and_scores: Optional[Dict[int, torch.Tensor]] = (None,) - - @property - def key(self) -> str: - """Return a string representation of self.ys""" - return "_".join(map(str, self.ys)) - - -class HypothesisList(object): - def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: - """ - Args: - data: - A dict of Hypotheses. Its key is its `value.key`. - """ - if data is None: - self._data = {} - else: - self._data = data - - @property - def data(self) -> Dict[str, Hypothesis]: - return self._data - - def add(self, hyp: Hypothesis) -> None: - """Add a Hypothesis to `self`. - - If `hyp` already exists in `self`, its probability is updated using - `log-sum-exp` with the existed one. - - Args: - hyp: - The hypothesis to be added. - """ - key = hyp.key - if key in self: - old_hyp = self._data[key] # shallow copy - - if True: - old_hyp.log_prob = torch.logaddexp( - old_hyp.log_prob, hyp.log_prob - ) - else: - old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob) - - if hyp.ngram_state_and_scores is not None: - for state, score in hyp.ngram_state_and_scores.items(): - if ( - state in old_hyp.ngram_state_and_scores - and score > old_hyp.ngram_state_and_scores[state] - ): - old_hyp.ngram_state_and_scores[state] = score - else: - old_hyp.ngram_state_and_scores[state] = score - else: - self._data[key] = hyp - - def get_most_probable(self, length_norm: bool = False) -> Hypothesis: - """Get the most probable hypothesis, i.e., the one with - the largest `log_prob`. - - Args: - length_norm: - If True, the `log_prob` of a hypothesis is normalized by the - number of tokens in it. - Returns: - Return the hypothesis that has the largest `log_prob`. 
- """ - if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) - else: - return max(self._data.values(), key=lambda hyp: hyp.log_prob) - - def remove(self, hyp: Hypothesis) -> None: - """Remove a given hypothesis. - - Caution: - `self` is modified **in-place**. - - Args: - hyp: - The hypothesis to be removed from `self`. - Note: It must be contained in `self`. Otherwise, - an exception is raised. - """ - key = hyp.key - assert key in self, f"{key} does not exist" - del self._data[key] - - def filter(self, threshold: torch.Tensor) -> "HypothesisList": - """Remove all Hypotheses whose log_prob is less than threshold. - - Caution: - `self` is not modified. Instead, a new HypothesisList is returned. - - Returns: - Return a new HypothesisList containing all hypotheses from `self` - with `log_prob` being greater than the given `threshold`. - """ - ans = HypothesisList() - for _, hyp in self._data.items(): - if hyp.log_prob > threshold: - ans.add(hyp) # shallow copy - return ans - - def topk(self, k: int) -> "HypothesisList": - """Return the top-k hypothesis.""" - hyps = list(self._data.items()) - - hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] - - ans = HypothesisList(dict(hyps)) - return ans - - def __contains__(self, key: str): - return key in self._data - - def __iter__(self): - return iter(self._data.values()) - - def __len__(self) -> int: - return len(self._data) - - def __str__(self) -> str: - s = [] - for key in self: - s.append(key) - return ", ".join(s) - - def run_decoder( ys: List[int], model: Transducer, @@ -341,6 +194,113 @@ def modified_beam_search( model: Transducer, encoder_out: torch.Tensor, beam: int = 4, +) -> List[int]: + """It limits the maximum number of symbols per frame to 1. + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + Returns: + Return the decoded result. 
+ """ + + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + + device = model.device + + decoder_input = torch.tensor( + [blank_id] * context_size, device=device + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + + T = encoder_out.size(1) + + B = HypothesisList() + B.add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + + encoder_out_len = torch.tensor([1]) + decoder_out_len = torch.tensor([1]) + + for t in range(T): + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :] + # current_encoder_out is of shape (1, 1, encoder_out_dim) + # fmt: on + A = list(B) + B = HypothesisList() + + ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) + # ys_log_probs is of shape (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyp in A], + device=device, + ) + # decoder_input is of shape (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + # decoder_output is of shape (num_hyps, 1, decoder_output_dim) + + current_encoder_out = current_encoder_out.expand( + decoder_out.size(0), 1, -1 + ) + + logits = model.joiner( + current_encoder_out, + decoder_out, + encoder_out_len.expand(decoder_out.size(0)), + decoder_out_len.expand(decoder_out.size(0)), + ) + # logits is of shape (num_hyps, vocab_size) + log_probs = logits.log_softmax(dim=-1) + + log_probs.add_(ys_log_probs) + + log_probs = log_probs.reshape(-1) + topk_log_probs, topk_indexes = log_probs.topk(beam) + + # topk_hyp_indexes are indexes into `A` + topk_hyp_indexes = topk_indexes // logits.size(-1) + topk_token_indexes = topk_indexes % logits.size(-1) + + topk_hyp_indexes = topk_hyp_indexes.tolist() + topk_token_indexes = topk_token_indexes.tolist() + + for i in range(len(topk_hyp_indexes)): + hyp = A[topk_hyp_indexes[i]] + new_ys = hyp.ys[:] + new_token = topk_token_indexes[i] + if new_token != blank_id: + new_ys.append(new_token) + new_log_prob = topk_log_probs[i] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B.add(new_hyp) + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + + return ys + + +def modified_beam_search_with_shallow_fusion( + model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, LG: Optional[k2.Fsa] = None, ngram_lm_scale: float = 0.1, ) -> List[int]: @@ -408,7 +368,14 @@ def modified_beam_search( A = list(B) B = HypothesisList() - ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) + # ys_log_probs contains both AM scores and LM scores + ys_log_probs = torch.cat( + [ + hyp.log_prob.reshape(1, 1) + + ngram_lm_scale * max(hyp.ngram_state_and_scores.values()) + for hyp in A + ] + ) # ys_log_probs is of shape (num_hyps, 1) decoder_input = torch.tensor( @@ -434,62 +401,52 @@ def modified_beam_search( # logits is of shape (num_hyps, vocab_size) log_probs = logits.log_softmax(dim=-1) - log_probs.add_(ys_log_probs) + tot_log_probs = log_probs + ys_log_probs - log_probs = log_probs.reshape(-1) - topk_log_probs, topk_indexes = log_probs.topk(beam) + _, topk_indexes = tot_log_probs.reshape(-1).topk(beam) + topk_log_probs = log_probs.reshape(-1)[topk_indexes] # topk_hyp_indexes are indexes into `A` topk_hyp_indexes = topk_indexes // logits.size(-1) topk_token_indexes = topk_indexes % 
logits.size(-1) - topk_hyp_indexes = topk_hyp_indexes.tolist() - topk_token_indexes = topk_token_indexes.tolist() + topk_hyp_indexes, indexes = torch.sort(topk_hyp_indexes) + topk_token_indexes = topk_token_indexes[indexes] + topk_log_probs = topk_log_probs[indexes] - # import pdb - # - # pdb.set_trace() - for i in range(len(topk_hyp_indexes)): - hyp = A[topk_hyp_indexes[i]] - new_ys = hyp.ys[:] - new_token = topk_token_indexes[i] - if new_token != blank_id: - new_ys.append(new_token) - else: - ngram_state_and_scores = hyp.ngram_state_and_scores + shape = k2.ragged.create_ragged_shape2( + row_ids=topk_hyp_indexes.to(torch.int32), + cached_tot_size=topk_hyp_indexes.numel(), + ) + blank_log_probs = log_probs[topk_hyp_indexes, 0] - new_log_prob = topk_log_probs[i] + row_splits = shape.row_splits(1).tolist() + num_rows = len(row_splits) - 1 + for i in range(num_rows): + start = row_splits[i] + end = row_splits[i + 1] + if start >= end: + # Discard A[i] as other hyps have higher log_probs + continue + tokens = topk_token_indexes[start:end] - if enable_shallow_fusion and new_token != blank_id: - ngram_state_and_scores = shallow_fusion( - LG, - new_token, - hyp.ngram_state_and_scores, - vocab_size, - ) - if len(ngram_state_and_scores) == 0: - continue - max_ngram_score = max(ngram_state_and_scores.values()) - new_log_prob = new_log_prob + ngram_lm_scale * max_ngram_score - - # TODO: Get the maximum scores in ngram_state_and_scores - # and add it to new_log_prob - - new_hyp = Hypothesis( - ys=new_ys, - log_prob=new_log_prob, - ngram_state_and_scores=ngram_state_and_scores, + hyps = shallow_fusion( + LG, + A[i], + tokens, + topk_log_probs[start:end], + vocab_size, + blank_log_probs[i], ) - - B.add(new_hyp) - if len(B) == 0: - import logging - - logging.info("\n*****\nEmpty states!\n***\n") - for h in A: + for h in hyps: B.add(h) - best_hyp = B.get_most_probable(length_norm=True) + if len(B) > beam: + B = B.topk(beam, ngram_lm_scale=ngram_lm_scale) + + best_hyp = B.get_most_probable( + length_norm=True, ngram_lm_scale=ngram_lm_scale + ) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks return ys diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py index b70b97d70..abd5e7fe9 100755 --- a/egs/librispeech/ASR/transducer_stateless/decode.py +++ b/egs/librispeech/ASR/transducer_stateless/decode.py @@ -47,7 +47,12 @@ import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from beam_search import beam_search, greedy_search, modified_beam_search +from beam_search import ( + beam_search, + greedy_search, + modified_beam_search, + modified_beam_search_with_shallow_fusion, +) from conformer import Conformer from decoder import Decoder from joiner import Joiner @@ -283,23 +288,25 @@ def decode_one_batch( beam=params.beam_size, ) elif params.decoding_method == "modified_beam_search": - hyp = modified_beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - LG=LG, - ngram_lm_scale=params.ngram_lm_scale, - ) + if LG is None: + hyp = modified_beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + hyp = modified_beam_search_with_shallow_fusion( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + LG=LG, + ngram_lm_scale=params.ngram_lm_scale, + ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" ) hyps.append(sp.decode(hyp).split()) - s = "\n" - for h in hyps: - 
s += " ".join(h) - s += "\n" - logging.info(s) if params.decoding_method == "greedy_search": return {"greedy_search": hyps} @@ -349,8 +356,6 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): - if batch_idx > 10: - break texts = batch["supervisions"]["text"] hyps_dict = decode_one_batch( @@ -464,6 +469,9 @@ def main(): ), "--LG is used only when --decoding_method=modified_beam_search" logging.info(f"Loading LG from {params.LG}") LG = k2.Fsa.from_dict(torch.load(params.LG, map_location=device)) + logging.info( + f"max: {LG.scores.max()}, min: {LG.scores.min()}, mean: {LG.scores.mean()}" + ) logging.info(f"LG properties: {LG.properties_str}") logging.info(f"LG num_states: {LG.shape[0]}, num_arcs: {LG.num_arcs}") # If LG is created by local/compile_lg.py, then it should be epsilon @@ -517,8 +525,6 @@ def main(): test_dl = [test_clean_dl, test_other_dl] for test_set, test_dl in zip(test_sets, test_dl): - if test_set == "test-other": - break results_dict = decode_dataset( dl=test_dl, params=params, diff --git a/egs/librispeech/ASR/transducer_stateless/shallow_fusion.py b/egs/librispeech/ASR/transducer_stateless/shallow_fusion.py index 8f1045d45..e74f9e350 100644 --- a/egs/librispeech/ASR/transducer_stateless/shallow_fusion.py +++ b/egs/librispeech/ASR/transducer_stateless/shallow_fusion.py @@ -19,39 +19,51 @@ from typing import Dict import k2 import torch import copy +from utils import Hypothesis, HypothesisList def shallow_fusion( LG: k2.Fsa, - token: int, - state_and_scores: Dict[int, torch.Tensor], + hyp: Hypothesis, + tokens: torch.Tensor, + log_probs: torch.Tensor, vocab_size: int, -) -> Dict[int, torch.Tensor]: + blank_log_prob: torch.Tensor, +) -> HypothesisList: """ Args: LG: An n-gram. It should be arc sorted, deterministic, and epsilon free. - token: - The input token ID. - state_and_scores: - The keys contain the current state we are in and the - values are the LM log_prob for reaching the corresponding - states from the start state. + It contains disambig IDs and back-off arcs. + hyp: + The current hypothesis. + tokens: + The possible tokens that will be expanded from the given `hyp`. + It is a 1-D tensor of dtype torch.int32. + log_probs: + It contains the acoustic log probabilities of each path that + is extended from `hyp.ys` with `tokens`. + log_probs.shape == tokens.shape. vocab_size: Vocabulary size, including the blank symbol. We assume that token IDs >= vocab_size are disambig IDs (including the backoff symbol #0). + blank_log_prob: + The log_prob for the blank token at this frame. It is from + the output of the joiner. Returns: - Return a new state_and_scores. + Return new hypotheses by extending the given `hyp` with tokens in the + given `tokens`. 
""" + row_splits = LG.arcs.row_splits(1) arcs = LG.arcs.values() - state_and_scores = copy.deepcopy(state_and_scores) + state_and_scores = copy.deepcopy(hyp.ngram_state_and_scores) current_states = list(state_and_scores.keys()) - # Process out-going arcs with label being disambig tokens and #0 + # Process out-going arcs with label equal to disambig tokens or #0 while len(current_states) > 0: s = current_states.pop() labels_begin = row_splits[s] @@ -84,7 +96,9 @@ def shallow_fusion( ) current_states = list(state_and_scores.keys()) - ans = dict() + ans = HypothesisList() + + device = log_probs.device for s in current_states: labels_begin = row_splits[s] labels_end = row_splits[s + 1] @@ -93,17 +107,47 @@ def shallow_fusion( if labels[-1] == -1: labels = labels[:-1] - pos = torch.searchsorted(labels, token) - if pos >= labels.numel() or labels[pos] != token: - continue + if s != 0: + # We add a backoff arc to the start state. Otherwise, + # all activate state may die due to out-of-Vocabulary word. + new_hyp = Hypothesis( + ys=hyp.ys[:], + log_prob=hyp.log_prob + blank_log_prob, + ngram_state_and_scores={ + # -20 is the cost on the backoff arc to the start state. + # As LG.scores.min() is about -16.6, we choose -20 here. + # You may need to tune this value. + 0: torch.full((1,), -20, dtype=torch.float32, device=device) + }, + ) + ans.add(new_hyp) - idx = labels_begin + pos - next_state = arcs[idx][1].item() - score = LG.scores[idx] + state_and_scores[s] + pos = torch.searchsorted(labels, tokens) + for i in range(pos.numel()): + if tokens[i] == 0: + # blank ID + new_hyp = Hypothesis( + ys=hyp.ys[:], + log_prob=hyp.log_prob + log_probs[i], + ngram_state_and_scores=hyp.ngram_state_and_scores, + ) + ans.add(new_hyp) + continue + elif pos[i] >= labels.numel() or labels[pos[i]] != tokens[i]: + # No out-going arcs from this state has labels + # equal to tokens[i] + continue - if next_state not in ans: - ans[next_state] = score - else: - ans[next_state] = max(score, ans[next_state]) + # Found one arc + + idx = labels_begin + pos[i] + next_state = arcs[idx][1].item() + score = LG.scores[idx] + state_and_scores[s] + new_hyp = Hypothesis( + ys=hyp.ys + [tokens[i].item()], + log_prob=hyp.log_prob + log_probs[i], + ngram_state_and_scores={next_state: score}, + ) + ans.add(new_hyp) return ans diff --git a/egs/librispeech/ASR/transducer_stateless/utils.py b/egs/librispeech/ASR/transducer_stateless/utils.py new file mode 100644 index 000000000..97f6a740d --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless/utils.py @@ -0,0 +1,219 @@ +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict, List, Optional + +import torch + + +@dataclass +class Hypothesis: + # The predicted tokens so far. + # Newly predicted tokens are appended to `ys`. + ys: List[int] + + # The log prob of ys. + # It contains only one entry. 
+    # Note: It contains only the acoustic part.
+    log_prob: torch.Tensor
+
+    # Used for shallow fusion.
+    # The key of the dict is a state index into LG
+    # while the corresponding value is the LM score for
+    # reaching this state from the start state.
+    # Note: The value tensor contains only a single entry
+    # and it contains only the LM part.
+    ngram_state_and_scores: Optional[Dict[int, torch.Tensor]] = None
+
+    @property
+    def key(self) -> str:
+        """Return a string representation of self.ys"""
+        return "_".join(map(str, self.ys))
+
+
+class HypothesisList(object):
+    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
+        """
+        Args:
+          data:
+            A dict of Hypotheses. Its key is its `value.key`.
+        """
+        if data is None:
+            self._data = {}
+        else:
+            self._data = data
+
+    @property
+    def data(self) -> Dict[str, Hypothesis]:
+        return self._data
+
+    def add(self, hyp: Hypothesis) -> None:
+        """Add a Hypothesis to `self`.
+
+        If `hyp` already exists in `self`, its `log_prob` is updated to
+        the maximum of the existing log_prob and the new one, and its
+        `ngram_state_and_scores` is merged state by state, keeping the
+        larger LM score for each state.
+
+        Args:
+          hyp:
+            The hypothesis to be added.
+        """
+        key = hyp.key
+        if key in self:
+            old_hyp = self._data[key]  # shallow copy
+
+            old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob)
+
+            if hyp.ngram_state_and_scores is not None:
+                for state, score in hyp.ngram_state_and_scores.items():
+                    # Keep the larger LM score for each LG state
+                    if (
+                        state not in old_hyp.ngram_state_and_scores
+                        or score > old_hyp.ngram_state_and_scores[state]
+                    ):
+                        old_hyp.ngram_state_and_scores[state] = score
+        else:
+            self._data[key] = hyp
+
+    def get_most_probable(
+        self, length_norm: bool = False, ngram_lm_scale: Optional[float] = None
+    ) -> Hypothesis:
+        """Get the most probable hypothesis, i.e., the one with
+        the largest `log_prob`.
+
+        Args:
+          length_norm:
+            If True, the `log_prob` of a hypothesis is normalized by the
+            number of tokens in it.
+          ngram_lm_scale:
+            If not None, it specifies the scale applied to the LM score.
+        Returns:
+          Return the hypothesis that has the largest `log_prob`.
+        """
+        if length_norm:
+            if ngram_lm_scale is None:
+                return max(
+                    self._data.values(),
+                    key=lambda hyp: hyp.log_prob / len(hyp.ys),
+                )
+            else:
+                return max(
+                    self._data.values(),
+                    key=lambda hyp: (
+                        hyp.log_prob
+                        + ngram_lm_scale
+                        * max(hyp.ngram_state_and_scores.values())
+                    )
+                    / len(hyp.ys),
+                )
+        else:
+            if ngram_lm_scale is None:
+                return max(self._data.values(), key=lambda hyp: hyp.log_prob)
+            else:
+                return max(
+                    self._data.values(),
+                    key=lambda hyp: hyp.log_prob
+                    + ngram_lm_scale * max(hyp.ngram_state_and_scores.values()),
+                )
+
+    def remove(self, hyp: Hypothesis) -> None:
+        """Remove a given hypothesis.
+
+        Caution:
+          `self` is modified **in-place**.
+
+        Args:
+          hyp:
+            The hypothesis to be removed from `self`.
+            Note: It must be contained in `self`. Otherwise,
+            an exception is raised.
+        """
+        key = hyp.key
+        assert key in self, f"{key} does not exist"
+        del self._data[key]
+
+    def filter(
+        self, threshold: torch.Tensor, ngram_lm_scale: Optional[float] = None
+    ) -> "HypothesisList":
+        """Remove all Hypotheses whose log_prob is less than threshold.
+
+        Caution:
+          `self` is not modified. Instead, a new HypothesisList is returned.
+
+        Args:
+          threshold:
+            Hypotheses with log_prob less than this value are removed.
+          ngram_lm_scale:
+            If not None, it specifies the scale applied to the LM score.
+
+        Returns:
+          Return a new HypothesisList containing all hypotheses from `self`
+          whose `log_prob` is greater than the given `threshold`. If
+          `ngram_lm_scale` is not None, the scaled LM score is included
+          in the comparison.
+        """
+        ans = HypothesisList()
+        if ngram_lm_scale is None:
+            for _, hyp in self._data.items():
+                if hyp.log_prob > threshold:
+                    ans.add(hyp)  # shallow copy
+        else:
+            for _, hyp in self._data.items():
+                if (
+                    hyp.log_prob
+                    + ngram_lm_scale * max(hyp.ngram_state_and_scores.values())
+                    > threshold
+                ):
+                    ans.add(hyp)  # shallow copy
+        return ans
+
+    def topk(
+        self, k: int, ngram_lm_scale: Optional[float] = None
+    ) -> "HypothesisList":
+        """Return the top-k hypotheses, ranked by `log_prob` plus the
+        scaled LM score if `ngram_lm_scale` is not None."""
+        hyps = list(self._data.items())
+
+        if ngram_lm_scale is None:
+            hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
+        else:
+            hyps = sorted(
+                hyps,
+                key=lambda h: h[1].log_prob
+                + ngram_lm_scale * max(h[1].ngram_state_and_scores.values()),
+                reverse=True,
+            )[:k]
+
+        ans = HypothesisList(dict(hyps))
+        return ans
+
+    def __contains__(self, key: str) -> bool:
+        return key in self._data
+
+    def __iter__(self):
+        return iter(self._data.values())
+
+    def __len__(self) -> int:
+        return len(self._data)
+
+    def __str__(self) -> str:
+        s = []
+        for hyp in self:
+            s.append(hyp.key)
+        return ", ".join(s)
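
Usage note: below is a minimal sketch of how the new
`modified_beam_search_with_shallow_fusion` entry point is invoked,
mirroring the call in decode.py. The "LG.pt" path and the `model` /
`encoder_out` variables are placeholders: `model` is assumed to be a
trained Transducer already moved to `device`, and `encoder_out` the
encoder output for a single utterance, of shape (1, T, C).

    import k2
    import torch

    from beam_search import modified_beam_search_with_shallow_fusion

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # LG should be the arc-sorted, epsilon-free FSA produced by
    # local/compile_lg.py; "LG.pt" is a placeholder path.
    LG = k2.Fsa.from_dict(torch.load("LG.pt", map_location=device))

    # encoder_out = model.encoder(...)  # shape (1, T, C); only N == 1
    # is supported for now.
    hyp = modified_beam_search_with_shallow_fusion(
        model=model,
        encoder_out=encoder_out,
        beam=4,
        LG=LG,
        ngram_lm_scale=0.1,
    )
    # hyp is a list of token IDs; convert it to words with the BPE
    # model, e.g., sp.decode(hyp).split(), as done in decode.py.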