Minor fixes.

Author: Fangjun Kuang
Date:   2022-01-29 12:59:01 +08:00
Parent: 77261bc575
Commit: 534cac4406


@@ -109,8 +109,6 @@ class Hypothesis:
     # The log prob of ys.
     # It contains only one entry.
-    # TODO(fangjun): It was a float before. We need to change its usage
-    # in greedy_search and beam_search.
     log_prob: torch.Tensor

     @property
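
For context, a minimal standalone sketch (not the icefall source) of what the field now holds: a one-element torch.Tensor instead of a Python float. The imports and example values below are illustrative only.

from dataclasses import dataclass
from typing import List

import torch


@dataclass
class Hypothesis:
    # The predicted tokens so far (starting with the initial blanks).
    ys: List[int]
    # The log prob of ys; a tensor containing a single entry,
    # not a Python float.
    log_prob: torch.Tensor


hyp = Hypothesis(ys=[0, 0], log_prob=torch.zeros(1, dtype=torch.float32))
print(hyp.log_prob)  # tensor([0.])
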
@@ -188,7 +186,7 @@ class HypothesisList(object):
         assert key in self, f"{key} does not exist"
         del self._data[key]

-    def filter(self, threshold: float) -> "HypothesisList":
+    def filter(self, threshold: torch.Tensor) -> "HypothesisList":
         """Remove all Hypotheses whose log_prob is less than threshold.

         Caution:
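
With this signature change, callers pass a one-element tensor as the threshold and the comparison is done between tensors. A standalone sketch of that comparison follows; the plain list stands in for a HypothesisList and the numbers are made up.

import torch

# log probs of three hypotheses, each a one-element tensor
log_probs = [torch.tensor([-1.2]), torch.tensor([-3.5]), torch.tensor([-0.7])]

# e.g. the best log prob minus a beam width; value is made up
threshold = torch.tensor([-2.0])

# `lp > threshold` yields a one-element bool tensor, which `if` accepts
kept = [lp for lp in log_probs if lp > threshold]
print(len(kept))  # 2
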
@@ -462,7 +460,12 @@ def beam_search(
     t = 0

     B = HypothesisList()
-    B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
+    B.add(
+        Hypothesis(
+            ys=[blank_id] * context_size,
+            log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+        )
+    )

     max_sym_per_utt = 20000
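
One reading of this change: starting from a zero tensor on the model's device lets every later log-prob update stay as tensor arithmetic instead of mixing in Python floats via .item(). A toy illustration with made-up values:

import torch

device = torch.device("cpu")  # in beam_search this is the model's device

# the starting hypothesis carries a tensor log prob ...
log_prob = torch.zeros(1, dtype=torch.float32, device=device)

# ... so later updates are plain tensor additions on that device
skip_log_prob = torch.tensor([-0.1054], device=device)  # made-up blank log prob
print(log_prob + skip_log_prob)  # tensor([-0.1054])
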
@@ -482,9 +485,6 @@ def beam_search(
         joint_cache: Dict[str, torch.Tensor] = {}

-        # TODO(fangjun): Implement prefix search to update the `log_prob`
-        # of hypotheses in A
-
         while True:
             y_star = A.get_most_probable()
             A.remove(y_star)
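
With log_prob stored as a tensor, picking the best hypothesis in A still needs an unambiguous scalar comparison. A stand-in sketch, not the real HypothesisList or get_most_probable implementation:

import torch

# stand-in hypotheses: (ys, log_prob) with one-element tensor log probs
hyps = [
    ([0, 0, 3], torch.tensor([-1.2])),
    ([0, 0, 5], torch.tensor([-0.4])),
]

# .item() turns the one-element tensor into a float for the comparison
ys_star, log_prob_star = max(hyps, key=lambda h: h[1].item())
print(ys_star)  # [0, 0, 5]
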
@@ -507,18 +507,21 @@ def beam_search(
             # First, process the blank symbol
             skip_log_prob = log_prob[blank_id]
-            new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
+            new_y_star_log_prob = y_star.log_prob + skip_log_prob

             # ys[:] returns a copy of ys
             B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))

             # Second, process other non-blank labels
             values, indices = log_prob.topk(beam + 1)
-            for i, v in zip(indices.tolist(), values.tolist()):
+            for idx in range(values.size(0)):
+                i = indices[idx].item()
                 if i == blank_id:
                     continue
                 new_ys = y_star.ys + [i]
-                new_log_prob = y_star.log_prob + v
+                new_log_prob = y_star.log_prob + values[idx]
                 A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))

             # Check whether B contains more than "beam" elements more probable
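
To see the new indexing pattern in isolation, here is a self-contained toy run of the expansion step. The vocabulary size, beam, and starting hypothesis are made up, and a plain list replaces the HypothesisList `A`.

import torch

blank_id = 0
beam = 4

# made-up log probs over a toy vocabulary of size 6 for the current frame
log_prob = torch.log_softmax(torch.randn(6), dim=-1)

y_star_ys = [blank_id, blank_id]
y_star_log_prob = torch.zeros(1)

expansions = []  # stands in for adding to the HypothesisList `A`
values, indices = log_prob.topk(beam + 1)
for idx in range(values.size(0)):
    i = indices[idx].item()
    if i == blank_id:
        continue
    new_ys = y_star_ys + [i]
    # values[idx] is a 0-dim tensor, so the sum stays a tensor
    new_log_prob = y_star_log_prob + values[idx]
    expansions.append((new_ys, new_log_prob))

print(len(expansions))  # at most beam + 1, minus the skipped blank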