mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 10:44:19 +00:00
Minor fixes.
This commit is contained in:
parent
77261bc575
commit
534cac4406
@ -109,8 +109,6 @@ class Hypothesis:
|
|||||||
|
|
||||||
# The log prob of ys.
|
# The log prob of ys.
|
||||||
# It contains only one entry.
|
# It contains only one entry.
|
||||||
# TODO(fangjun): It was a float before. We need to change its usage
|
|
||||||
# in greedy_search and beam_search.
|
|
||||||
log_prob: torch.Tensor
|
log_prob: torch.Tensor
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -188,7 +186,7 @@ class HypothesisList(object):
|
|||||||
assert key in self, f"{key} does not exist"
|
assert key in self, f"{key} does not exist"
|
||||||
del self._data[key]
|
del self._data[key]
|
||||||
|
|
||||||
def filter(self, threshold: float) -> "HypothesisList":
|
def filter(self, threshold: torch.Tensor) -> "HypothesisList":
|
||||||
"""Remove all Hypotheses whose log_prob is less than threshold.
|
"""Remove all Hypotheses whose log_prob is less than threshold.
|
||||||
|
|
||||||
Caution:
|
Caution:
|
||||||
@ -462,7 +460,12 @@ def beam_search(
|
|||||||
t = 0
|
t = 0
|
||||||
|
|
||||||
B = HypothesisList()
|
B = HypothesisList()
|
||||||
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
|
B.add(
|
||||||
|
Hypothesis(
|
||||||
|
ys=[blank_id] * context_size,
|
||||||
|
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
max_sym_per_utt = 20000
|
max_sym_per_utt = 20000
|
||||||
|
|
||||||
@ -482,9 +485,6 @@ def beam_search(
|
|||||||
|
|
||||||
joint_cache: Dict[str, torch.Tensor] = {}
|
joint_cache: Dict[str, torch.Tensor] = {}
|
||||||
|
|
||||||
# TODO(fangjun): Implement prefix search to update the `log_prob`
|
|
||||||
# of hypotheses in A
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
y_star = A.get_most_probable()
|
y_star = A.get_most_probable()
|
||||||
A.remove(y_star)
|
A.remove(y_star)
|
||||||
@ -507,18 +507,21 @@ def beam_search(
|
|||||||
|
|
||||||
# First, process the blank symbol
|
# First, process the blank symbol
|
||||||
skip_log_prob = log_prob[blank_id]
|
skip_log_prob = log_prob[blank_id]
|
||||||
new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
|
new_y_star_log_prob = y_star.log_prob + skip_log_prob
|
||||||
|
|
||||||
# ys[:] returns a copy of ys
|
# ys[:] returns a copy of ys
|
||||||
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
|
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
|
||||||
|
|
||||||
# Second, process other non-blank labels
|
# Second, process other non-blank labels
|
||||||
values, indices = log_prob.topk(beam + 1)
|
values, indices = log_prob.topk(beam + 1)
|
||||||
for i, v in zip(indices.tolist(), values.tolist()):
|
for idx in range(values.size(0)):
|
||||||
|
i = indices[idx].item()
|
||||||
if i == blank_id:
|
if i == blank_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
new_ys = y_star.ys + [i]
|
new_ys = y_star.ys + [i]
|
||||||
new_log_prob = y_star.log_prob + v
|
|
||||||
|
new_log_prob = y_star.log_prob + values[idx]
|
||||||
A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
|
A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
|
||||||
|
|
||||||
# Check whether B contains more than "beam" elements more probable
|
# Check whether B contains more than "beam" elements more probable
|
||||||
|
Loading…
x
Reference in New Issue
Block a user