Add backoff arcs to the start state to handle OOV words.

Fangjun Kuang 2022-02-15 12:33:53 +08:00
parent 5af23efa69
commit adb54aea91
4 changed files with 459 additions and 233 deletions
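In words: shallow fusion tracks a set of active LG states per hypothesis, and before this change a token with no matching arc from any active state dropped the hypothesis (with an "Empty states!" fallback that copied over the previous beam). A minimal sketch of the fallback idea introduced here, with a hypothetical helper name; state 0 as the start state and the -20 backoff cost mirror shallow_fusion.py below:

import torch

def backoff_to_start(ngram_state_and_scores, backoff_cost=-20.0):
    # Hypothetical helper: if an OOV word leaves no surviving LG state,
    # restart the n-gram LM from the start state (index 0) with a fixed
    # penalty instead of dropping the hypothesis entirely. The penalty
    # should be tuned per LM; -20 is used because LG.scores.min() is
    # about -16.6 in this setup.
    if len(ngram_state_and_scores) == 0:
        return {0: torch.full((1,), backoff_cost, dtype=torch.float32)}
    return ngram_state_and_scores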

View File

@@ -14,13 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from dataclasses import dataclass
from typing import Dict, List, Optional
import k2
import torch
from model import Transducer
from shallow_fusion import shallow_fusion
+from utils import Hypothesis, HypothesisList
def greedy_search(
@@ -103,153 +103,6 @@ def greedy_search(
return hyp
@dataclass
class Hypothesis:
# The predicted tokens so far.
# Newly predicted tokens are appended to `ys`.
ys: List[int]
# The log prob of ys.
# It contains only one entry.
log_prob: torch.Tensor
# Used for shallow fusion
# The key of the dict is a state index into LG
# while the corresponding value is the LM score
# reaching this state.
# Note: The value tensor contains only a single entry
ngram_state_and_scores: Optional[Dict[int, torch.Tensor]] = (None,)
@property
def key(self) -> str:
"""Return a string representation of self.ys"""
return "_".join(map(str, self.ys))
class HypothesisList(object):
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
"""
Args:
data:
A dict of Hypotheses. Each key is the `key` of its corresponding value.
"""
if data is None:
self._data = {}
else:
self._data = data
@property
def data(self) -> Dict[str, Hypothesis]:
return self._data
def add(self, hyp: Hypothesis) -> None:
"""Add a Hypothesis to `self`.
If `hyp` already exists in `self`, its probability is updated using
`log-sum-exp` with the existing one.
Args:
hyp:
The hypothesis to be added.
"""
key = hyp.key
if key in self:
old_hyp = self._data[key] # shallow copy
if True:
old_hyp.log_prob = torch.logaddexp(
old_hyp.log_prob, hyp.log_prob
)
else:
old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob)
if hyp.ngram_state_and_scores is not None:
for state, score in hyp.ngram_state_and_scores.items():
if (
state in old_hyp.ngram_state_and_scores
and score > old_hyp.ngram_state_and_scores[state]
):
old_hyp.ngram_state_and_scores[state] = score
else:
old_hyp.ngram_state_and_scores[state] = score
else:
self._data[key] = hyp
def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
"""Get the most probable hypothesis, i.e., the one with
the largest `log_prob`.
Args:
length_norm:
If True, the `log_prob` of a hypothesis is normalized by the
number of tokens in it.
Returns:
Return the hypothesis that has the largest `log_prob`.
"""
if length_norm:
return max(
self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
)
else:
return max(self._data.values(), key=lambda hyp: hyp.log_prob)
def remove(self, hyp: Hypothesis) -> None:
"""Remove a given hypothesis.
Caution:
`self` is modified **in-place**.
Args:
hyp:
The hypothesis to be removed from `self`.
Note: It must be contained in `self`. Otherwise,
an exception is raised.
"""
key = hyp.key
assert key in self, f"{key} does not exist"
del self._data[key]
def filter(self, threshold: torch.Tensor) -> "HypothesisList":
"""Remove all Hypotheses whose log_prob is less than threshold.
Caution:
`self` is not modified. Instead, a new HypothesisList is returned.
Returns:
Return a new HypothesisList containing all hypotheses from `self`
with `log_prob` being greater than the given `threshold`.
"""
ans = HypothesisList()
for _, hyp in self._data.items():
if hyp.log_prob > threshold:
ans.add(hyp) # shallow copy
return ans
def topk(self, k: int) -> "HypothesisList":
"""Return the top-k hypotheses."""
hyps = list(self._data.items())
hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
ans = HypothesisList(dict(hyps))
return ans
def __contains__(self, key: str):
return key in self._data
def __iter__(self):
return iter(self._data.values())
def __len__(self) -> int:
return len(self._data)
def __str__(self) -> str:
s = []
for key in self:
s.append(key)
return ", ".join(s)
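As an aside on the `add()` method above: `torch.logaddexp` merges two paths in probability space, so merging two equal-score paths gains exactly log 2. A quick sanity check:

import torch

a = torch.tensor([-1.0])
b = torch.tensor([-1.0])
# log(exp(a) + exp(b)); two equal-score paths merge to score + log(2)
merged = torch.logaddexp(a, b)
assert torch.allclose(merged, a + torch.log(torch.tensor(2.0)))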
def run_decoder(
ys: List[int],
model: Transducer,
@@ -341,6 +194,113 @@ def modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
) -> List[int]:
"""It limits the maximum number of symbols per frame to 1.
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Only N == 1 is supported for now.
beam:
Beam size.
Returns:
Return the decoded result.
"""
assert encoder_out.ndim == 3
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
decoder_input = torch.tensor(
[blank_id] * context_size, device=device
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
T = encoder_out.size(1)
B = HypothesisList()
B.add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
)
encoder_out_len = torch.tensor([1])
decoder_out_len = torch.tensor([1])
for t in range(T):
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
# current_encoder_out is of shape (1, 1, encoder_out_dim)
# fmt: on
A = list(B)
B = HypothesisList()
ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A])
# ys_log_probs is of shape (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyp in A],
device=device,
)
# decoder_input is of shape (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
# decoder_out is of shape (num_hyps, 1, decoder_out_dim)
current_encoder_out = current_encoder_out.expand(
decoder_out.size(0), 1, -1
)
logits = model.joiner(
current_encoder_out,
decoder_out,
encoder_out_len.expand(decoder_out.size(0)),
decoder_out_len.expand(decoder_out.size(0)),
)
# logits is of shape (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1)
log_probs.add_(ys_log_probs)
log_probs = log_probs.reshape(-1)
topk_log_probs, topk_indexes = log_probs.topk(beam)
# topk_hyp_indexes are indexes into `A`
topk_hyp_indexes = topk_indexes // logits.size(-1)
topk_token_indexes = topk_indexes % logits.size(-1)
topk_hyp_indexes = topk_hyp_indexes.tolist()
topk_token_indexes = topk_token_indexes.tolist()
for i in range(len(topk_hyp_indexes)):
hyp = A[topk_hyp_indexes[i]]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[i]
if new_token != blank_id:
new_ys.append(new_token)
new_log_prob = topk_log_probs[i]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B.add(new_hyp)
best_hyp = B.get_most_probable(length_norm=True)
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
return ys
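A minimal sketch of the flattened-topk index arithmetic used above, with made-up shapes; integer division recovers the source hypothesis and modulo the token:

import torch

log_probs = torch.randn(3, 5)  # (num_hyps, vocab_size)
topk_log_probs, topk_indexes = log_probs.reshape(-1).topk(2)
# A flattened index encodes hyp * vocab_size + token, so // and %
# recover the two components.
topk_hyp_indexes = topk_indexes // log_probs.size(-1)
topk_token_indexes = topk_indexes % log_probs.size(-1)
assert torch.equal(
    log_probs[topk_hyp_indexes, topk_token_indexes], topk_log_probs
)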
def modified_beam_search_with_shallow_fusion(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
LG: Optional[k2.Fsa] = None,
ngram_lm_scale: float = 0.1,
) -> List[int]:
@@ -408,7 +368,14 @@ def modified_beam_search(
A = list(B)
B = HypothesisList()
-ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A])
+# ys_log_probs contains both AM scores and LM scores
+ys_log_probs = torch.cat(
+[
+hyp.log_prob.reshape(1, 1)
++ ngram_lm_scale * max(hyp.ngram_state_and_scores.values())
+for hyp in A
+]
+)
# ys_log_probs is of shape (num_hyps, 1)
decoder_input = torch.tensor(
@@ -434,62 +401,52 @@ def modified_beam_search(
# logits is of shape (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1)
-log_probs.add_(ys_log_probs)
-log_probs = log_probs.reshape(-1)
-topk_log_probs, topk_indexes = log_probs.topk(beam)
+tot_log_probs = log_probs + ys_log_probs
+_, topk_indexes = tot_log_probs.reshape(-1).topk(beam)
+topk_log_probs = log_probs.reshape(-1)[topk_indexes]
# topk_hyp_indexes are indexes into `A`
topk_hyp_indexes = topk_indexes // logits.size(-1)
topk_token_indexes = topk_indexes % logits.size(-1)
-topk_hyp_indexes = topk_hyp_indexes.tolist()
-topk_token_indexes = topk_token_indexes.tolist()
-# import pdb
-# pdb.set_trace()
-for i in range(len(topk_hyp_indexes)):
-hyp = A[topk_hyp_indexes[i]]
-new_ys = hyp.ys[:]
-new_token = topk_token_indexes[i]
-if new_token != blank_id:
-new_ys.append(new_token)
-else:
-ngram_state_and_scores = hyp.ngram_state_and_scores
-new_log_prob = topk_log_probs[i]
-if enable_shallow_fusion and new_token != blank_id:
-ngram_state_and_scores = shallow_fusion(
-LG,
-new_token,
-hyp.ngram_state_and_scores,
+topk_hyp_indexes, indexes = torch.sort(topk_hyp_indexes)
+topk_token_indexes = topk_token_indexes[indexes]
+topk_log_probs = topk_log_probs[indexes]
+shape = k2.ragged.create_ragged_shape2(
+row_ids=topk_hyp_indexes.to(torch.int32),
+cached_tot_size=topk_hyp_indexes.numel(),
+)
+blank_log_probs = log_probs[topk_hyp_indexes, 0]
+row_splits = shape.row_splits(1).tolist()
+num_rows = len(row_splits) - 1
+for i in range(num_rows):
+start = row_splits[i]
+end = row_splits[i + 1]
+if start >= end:
+# Discard A[i] as other hyps have higher log_probs
+continue
+tokens = topk_token_indexes[start:end]
+hyps = shallow_fusion(
+LG,
+A[i],
+tokens,
+topk_log_probs[start:end],
vocab_size,
-)
-if len(ngram_state_and_scores) == 0:
-continue
-max_ngram_score = max(ngram_state_and_scores.values())
-new_log_prob = new_log_prob + ngram_lm_scale * max_ngram_score
-# TODO: Get the maximum scores in ngram_state_and_scores
-# and add it to new_log_prob
-new_hyp = Hypothesis(
-ys=new_ys,
-log_prob=new_log_prob,
-ngram_state_and_scores=ngram_state_and_scores,
+blank_log_probs[i],
)
-B.add(new_hyp)
-if len(B) == 0:
-import logging
-logging.info("\n*****\nEmpty states!\n***\n")
-for h in A:
+for h in hyps:
B.add(h)
-best_hyp = B.get_most_probable(length_norm=True)
+if len(B) > beam:
+B = B.topk(beam, ngram_lm_scale=ngram_lm_scale)
+best_hyp = B.get_most_probable(
+length_norm=True, ngram_lm_scale=ngram_lm_scale
+)
ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks
return ys
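For reference, the grouping that `k2.ragged.create_ragged_shape2` provides above can be mimicked in plain torch: given sorted row_ids (the source-hypothesis index of each kept topk entry), row_splits marks where each hypothesis's slice begins and ends. A sketch with made-up values:

import torch

row_ids = torch.tensor([0, 0, 2, 2])  # sorted topk_hyp_indexes, beam = 4
num_rows = 3
# row_splits[i]:row_splits[i + 1] is the slice of entries drawn from A[i]
row_splits = torch.searchsorted(row_ids, torch.arange(num_rows + 1))
print(row_splits.tolist())  # [0, 2, 2, 4] -> A[1] is empty (start >= end)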

View File

@@ -47,7 +47,12 @@ import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
-from beam_search import beam_search, greedy_search, modified_beam_search
+from beam_search import (
+beam_search,
+greedy_search,
+modified_beam_search,
+modified_beam_search_with_shallow_fusion,
+)
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
@@ -283,23 +288,25 @@ def decode_one_batch(
beam=params.beam_size,
)
elif params.decoding_method == "modified_beam_search":
-hyp = modified_beam_search(
-model=model,
-encoder_out=encoder_out_i,
-beam=params.beam_size,
-LG=LG,
-ngram_lm_scale=params.ngram_lm_scale,
-)
+if LG is None:
+hyp = modified_beam_search(
+model=model,
+encoder_out=encoder_out_i,
+beam=params.beam_size,
+)
+else:
+hyp = modified_beam_search_with_shallow_fusion(
+model=model,
+encoder_out=encoder_out_i,
+beam=params.beam_size,
+LG=LG,
+ngram_lm_scale=params.ngram_lm_scale,
+)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
-s = "\n"
-for h in hyps:
-s += " ".join(h)
-s += "\n"
-logging.info(s)
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
@@ -349,8 +356,6 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
-if batch_idx > 10:
-break
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
@@ -464,6 +469,9 @@ def main():
), "--LG is used only when --decoding_method=modified_beam_search"
logging.info(f"Loading LG from {params.LG}")
LG = k2.Fsa.from_dict(torch.load(params.LG, map_location=device))
+logging.info(
+f"max: {LG.scores.max()}, min: {LG.scores.min()}, mean: {LG.scores.mean()}"
+)
logging.info(f"LG properties: {LG.properties_str}")
logging.info(f"LG num_states: {LG.shape[0]}, num_arcs: {LG.num_arcs}")
# If LG is created by local/compile_lg.py, then it should be epsilon
@@ -517,8 +525,6 @@ def main():
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
-if test_set == "test-other":
-break
results_dict = decode_dataset(
dl=test_dl,
params=params,

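A small sketch of inspecting an LG checkpoint outside decode.py, mirroring the logging added above; the path is illustrative:

import k2
import torch

LG = k2.Fsa.from_dict(
    torch.load("data/lang_bpe_500/LG.pt", map_location="cpu")
)
# Shallow fusion assumes LG is arc sorted and epsilon free.
print(LG.properties_str)
print(f"num_states: {LG.shape[0]}, num_arcs: {LG.num_arcs}")
print(f"max: {LG.scores.max()}, min: {LG.scores.min()}, mean: {LG.scores.mean()}")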
View File

@@ -19,39 +19,51 @@ from typing import Dict
import k2
import torch
import copy
+from utils import Hypothesis, HypothesisList
def shallow_fusion(
LG: k2.Fsa,
-token: int,
-state_and_scores: Dict[int, torch.Tensor],
+hyp: Hypothesis,
+tokens: torch.Tensor,
+log_probs: torch.Tensor,
vocab_size: int,
-) -> Dict[int, torch.Tensor]:
+blank_log_prob: torch.Tensor,
+) -> HypothesisList:
"""
Args:
LG:
An n-gram. It should be arc sorted, deterministic, and epsilon free.
-token:
-The input token ID.
-state_and_scores:
-The keys contain the current state we are in and the
-values are the LM log_prob for reaching the corresponding
-states from the start state.
+It contains disambig IDs and back-off arcs.
+hyp:
+The current hypothesis.
+tokens:
+The possible tokens that will be expanded from the given `hyp`.
+It is a 1-D tensor of dtype torch.int32.
+log_probs:
+It contains the acoustic log probabilities of each path that
+is extended from `hyp.ys` with `tokens`.
+log_probs.shape == tokens.shape.
vocab_size:
Vocabulary size, including the blank symbol. We assume that
token IDs >= vocab_size are disambig IDs (including the backoff
symbol #0).
+blank_log_prob:
+The log_prob for the blank token at this frame. It is from
+the output of the joiner.
Returns:
-Return a new state_and_scores.
+Return new hypotheses by extending the given `hyp` with tokens in the
+given `tokens`.
"""
row_splits = LG.arcs.row_splits(1)
arcs = LG.arcs.values()
-state_and_scores = copy.deepcopy(state_and_scores)
+state_and_scores = copy.deepcopy(hyp.ngram_state_and_scores)
current_states = list(state_and_scores.keys())
-# Process out-going arcs with label being disambig tokens and #0
+# Process out-going arcs with label equal to disambig tokens or #0
while len(current_states) > 0:
s = current_states.pop()
labels_begin = row_splits[s]
@@ -84,7 +96,9 @@ def shallow_fusion(
)
current_states = list(state_and_scores.keys())
-ans = dict()
+ans = HypothesisList()
+device = log_probs.device
for s in current_states:
labels_begin = row_splits[s]
labels_end = row_splits[s + 1]
@@ -93,17 +107,47 @@ def shallow_fusion(
if labels[-1] == -1:
labels = labels[:-1]
-pos = torch.searchsorted(labels, token)
-if pos >= labels.numel() or labels[pos] != token:
-continue
-idx = labels_begin + pos
-next_state = arcs[idx][1].item()
-score = LG.scores[idx] + state_and_scores[s]
-if next_state not in ans:
-ans[next_state] = score
-else:
-ans[next_state] = max(score, ans[next_state])
+if s != 0:
+# We add a backoff arc to the start state. Otherwise,
+# all active states may die due to an out-of-vocabulary word.
+new_hyp = Hypothesis(
+ys=hyp.ys[:],
+log_prob=hyp.log_prob + blank_log_prob,
+ngram_state_and_scores={
+# -20 is the cost on the backoff arc to the start state.
+# As LG.scores.min() is about -16.6, we choose -20 here.
+# You may need to tune this value.
+0: torch.full((1,), -20, dtype=torch.float32, device=device)
+},
+)
+ans.add(new_hyp)
+pos = torch.searchsorted(labels, tokens)
+for i in range(pos.numel()):
+if tokens[i] == 0:
+# blank ID
+new_hyp = Hypothesis(
+ys=hyp.ys[:],
+log_prob=hyp.log_prob + log_probs[i],
+ngram_state_and_scores=hyp.ngram_state_and_scores,
+)
+ans.add(new_hyp)
+continue
+elif pos[i] >= labels.numel() or labels[pos[i]] != tokens[i]:
+# No out-going arc from this state has a label
+# equal to tokens[i]
+continue
+# Found one arc
+idx = labels_begin + pos[i]
+next_state = arcs[idx][1].item()
+score = LG.scores[idx] + state_and_scores[s]
+new_hyp = Hypothesis(
+ys=hyp.ys + [tokens[i].item()],
+log_prob=hyp.log_prob + log_probs[i],
+ngram_state_and_scores={next_state: score},
+)
+ans.add(new_hyp)
return ans
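The lookup above relies on LG being arc sorted, so each state's out-going labels form a sorted array and `torch.searchsorted` can locate candidate tokens in O(log n). A toy sketch with made-up labels:

import torch

labels = torch.tensor([2, 5, 9])  # sorted out-going arc labels of one state
tokens = torch.tensor([5, 7])     # candidate tokens from the topk step
pos = torch.searchsorted(labels, tokens)
for i in range(pos.numel()):
    if pos[i] < labels.numel() and labels[pos[i]] == tokens[i]:
        print(f"token {tokens[i].item()} -> arc offset {pos[i].item()}")
    else:
        print(f"token {tokens[i].item()} has no matching arc")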

View File

@@ -0,0 +1,219 @@
# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, List, Optional
import torch
@dataclass
class Hypothesis:
# The predicted tokens so far.
# Newly predicted tokens are appended to `ys`.
ys: List[int]
# The log prob of ys.
# It contains only one entry.
# Note: It contains only the acoustic part.
log_prob: torch.Tensor
# Used for shallow fusion
# The key of the dict is a state index into LG
# while the corresponding value is the LM score
# reaching this state from the start state.
# Note: The value tensor contains only a single entry
# and it contains only the LM part.
ngram_state_and_scores: Optional[Dict[int, torch.Tensor]] = None
@property
def key(self) -> str:
"""Return a string representation of self.ys"""
return "_".join(map(str, self.ys))
class HypothesisList(object):
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
"""
Args:
data:
A dict of Hypotheses. Each key is the `key` of its corresponding value.
"""
if data is None:
self._data = {}
else:
self._data = data
@property
def data(self) -> Dict[str, Hypothesis]:
return self._data
def add(self, hyp: Hypothesis) -> None:
"""Add a Hypothesis to `self`.
If `hyp` already exists in `self`, its probability is updated using
`log-sum-exp` with the existing one.
Args:
hyp:
The hypothesis to be added.
"""
key = hyp.key
if key in self:
old_hyp = self._data[key] # shallow copy
if False:
old_hyp.log_prob = torch.logaddexp(
old_hyp.log_prob, hyp.log_prob
)
else:
old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob)
if hyp.ngram_state_and_scores is not None:
for state, score in hyp.ngram_state_and_scores.items():
if (
state in old_hyp.ngram_state_and_scores
and score > old_hyp.ngram_state_and_scores[state]
):
old_hyp.ngram_state_and_scores[state] = score
else:
old_hyp.ngram_state_and_scores[state] = score
else:
self._data[key] = hyp
def get_most_probable(
self, length_norm: bool = False, ngram_lm_scale: Optional[float] = None
) -> Hypothesis:
"""Get the most probable hypothesis, i.e., the one with
the largest `log_prob`.
Args:
length_norm:
If True, the `log_prob` of a hypothesis is normalized by the
number of tokens in it.
ngram_lm_scale:
If not None, it specifies the scale applied to the LM score.
Returns:
Return the hypothesis that has the largest `log_prob`.
"""
if length_norm:
if ngram_lm_scale is None:
return max(
self._data.values(),
key=lambda hyp: hyp.log_prob / len(hyp.ys),
)
else:
return max(
self._data.values(),
key=lambda hyp: (
hyp.log_prob
+ ngram_lm_scale
* max(hyp.ngram_state_and_scores.values())
)
/ len(hyp.ys),
)
else:
if ngram_lm_scale is None:
return max(self._data.values(), key=lambda hyp: hyp.log_prob)
else:
return max(
self._data.values(),
key=lambda hyp: hyp.log_prob
+ ngram_lm_scale * max(hyp.ngram_state_and_scores.values()),
)
def remove(self, hyp: Hypothesis) -> None:
"""Remove a given hypothesis.
Caution:
`self` is modified **in-place**.
Args:
hyp:
The hypothesis to be removed from `self`.
Note: It must be contained in `self`. Otherwise,
an exception is raised.
"""
key = hyp.key
assert key in self, f"{key} does not exist"
del self._data[key]
def filter(
self, threshold: torch.Tensor, ngram_lm_scale: Optional[float] = None
) -> "HypothesisList":
"""Remove all Hypotheses whose log_prob is less than threshold.
Caution:
`self` is not modified. Instead, a new HypothesisList is returned.
Args:
threshold:
Hypotheses with log_prob less than this value are removed.
ngram_lm_scale:
If not None, it specifies the scale applied to the LM score.
Returns:
Return a new HypothesisList containing all hypotheses from `self`
with `log_prob` being greater than the given `threshold`.
"""
ans = HypothesisList()
if ngram_lm_scale is None:
for _, hyp in self._data.items():
if hyp.log_prob > threshold:
ans.add(hyp) # shallow copy
else:
for _, hyp in self._data.items():
if (
hyp.log_prob
+ ngram_lm_scale * max(hyp.ngram_state_and_scores.values())
> threshold
):
ans.add(hyp) # shallow copy
return ans
def topk(
self, k: int, ngram_lm_scale: Optional[float] = None
) -> "HypothesisList":
"""Return the top-k hypotheses."""
hyps = list(self._data.items())
if ngram_lm_scale is None:
hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
else:
hyps = sorted(
hyps,
key=lambda h: h[1].log_prob
+ ngram_lm_scale * max(h[1].ngram_state_and_scores.values()),
reverse=True,
)[:k]
ans = HypothesisList(dict(hyps))
return ans
def __contains__(self, key: str):
return key in self._data
def __iter__(self):
return iter(self._data.values())
def __len__(self) -> int:
return len(self._data)
def __str__(self) -> str:
s = []
for key in self:
s.append(key)
return ", ".join(s)
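Finally, a small usage sketch of the classes above (scores are made up), showing how `add()` merges hypotheses with identical `ys` and how `ngram_lm_scale` enters the ranking in `get_most_probable()`; it assumes utils.py is importable:

import torch
from utils import Hypothesis, HypothesisList

hyps = HypothesisList()
hyps.add(Hypothesis(ys=[0, 0, 5], log_prob=torch.tensor([-1.0]),
                    ngram_state_and_scores={3: torch.tensor([-0.5])}))
# Same ys: merged with the existing entry (max of the log_probs here,
# since the logaddexp branch is disabled in add()), and the LM score
# per LG state is also maximized.
hyps.add(Hypothesis(ys=[0, 0, 5], log_prob=torch.tensor([-0.7]),
                    ngram_state_and_scores={3: torch.tensor([-0.2])}))
hyps.add(Hypothesis(ys=[0, 0, 8], log_prob=torch.tensor([-0.6]),
                    ngram_state_and_scores={1: torch.tensor([-3.0])}))
assert len(hyps) == 2
best = hyps.get_most_probable(ngram_lm_scale=0.1)
print(best.key)  # "0_0_5": -0.7 + 0.1 * (-0.2) = -0.72 beats -0.6 + 0.1 * (-3.0)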