mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Minor fixes.
This commit is contained in:
parent
77261bc575
commit
534cac4406
@ -109,8 +109,6 @@ class Hypothesis:
|
||||
|
||||
# The log prob of ys.
|
||||
# It contains only one entry.
|
||||
# TODO(fangjun): It was a float before. We need to change its usage
|
||||
# in greedy_search and beam_search.
|
||||
log_prob: torch.Tensor
|
||||
|
||||
@property
|
||||
@ -188,7 +186,7 @@ class HypothesisList(object):
|
||||
assert key in self, f"{key} does not exist"
|
||||
del self._data[key]
|
||||
|
||||
def filter(self, threshold: float) -> "HypothesisList":
|
||||
def filter(self, threshold: torch.Tensor) -> "HypothesisList":
|
||||
"""Remove all Hypotheses whose log_prob is less than threshold.
|
||||
|
||||
Caution:
|
||||
@ -462,7 +460,12 @@ def beam_search(
|
||||
t = 0
|
||||
|
||||
B = HypothesisList()
|
||||
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
|
||||
B.add(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
)
|
||||
)
|
||||
|
||||
max_sym_per_utt = 20000
|
||||
|
||||
@ -482,9 +485,6 @@ def beam_search(
|
||||
|
||||
joint_cache: Dict[str, torch.Tensor] = {}
|
||||
|
||||
# TODO(fangjun): Implement prefix search to update the `log_prob`
|
||||
# of hypotheses in A
|
||||
|
||||
while True:
|
||||
y_star = A.get_most_probable()
|
||||
A.remove(y_star)
|
||||
@ -507,18 +507,21 @@ def beam_search(
|
||||
|
||||
# First, process the blank symbol
|
||||
skip_log_prob = log_prob[blank_id]
|
||||
new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
|
||||
new_y_star_log_prob = y_star.log_prob + skip_log_prob
|
||||
|
||||
# ys[:] returns a copy of ys
|
||||
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
|
||||
|
||||
# Second, process other non-blank labels
|
||||
values, indices = log_prob.topk(beam + 1)
|
||||
for i, v in zip(indices.tolist(), values.tolist()):
|
||||
for idx in range(values.size(0)):
|
||||
i = indices[idx].item()
|
||||
if i == blank_id:
|
||||
continue
|
||||
|
||||
new_ys = y_star.ys + [i]
|
||||
new_log_prob = y_star.log_prob + v
|
||||
|
||||
new_log_prob = y_star.log_prob + values[idx]
|
||||
A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
|
||||
|
||||
# Check whether B contains more than "beam" elements more probable
|
||||
|
Loading…
x
Reference in New Issue
Block a user