mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 10:44:19 +00:00
Fix beam search
This commit is contained in:
parent
f4b8b0641a
commit
bfd331c27e
@ -128,7 +128,6 @@ class HypothesisList(object):
|
|||||||
def data(self):
|
def data(self):
|
||||||
return self._data
|
return self._data
|
||||||
|
|
||||||
# def add(self, ys: List[int], log_prob: float):
|
|
||||||
def add(self, hyp: Hypothesis):
|
def add(self, hyp: Hypothesis):
|
||||||
"""Add a Hypothesis to `self`.
|
"""Add a Hypothesis to `self`.
|
||||||
|
|
||||||
@ -266,7 +265,7 @@ def beam_search(
|
|||||||
|
|
||||||
while t < T and sym_per_utt < max_sym_per_utt:
|
while t < T and sym_per_utt < max_sym_per_utt:
|
||||||
# fmt: off
|
# fmt: off
|
||||||
current_encoder_out = encoder_out[:, t:t+1, :]
|
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
A = B
|
A = B
|
||||||
B = HypothesisList()
|
B = HypothesisList()
|
||||||
@ -294,9 +293,11 @@ def beam_search(
|
|||||||
|
|
||||||
cached_key += f"-t-{t}"
|
cached_key += f"-t-{t}"
|
||||||
if cached_key not in joint_cache:
|
if cached_key not in joint_cache:
|
||||||
logits = model.joiner(current_encoder_out, decoder_out)
|
logits = model.joiner(
|
||||||
|
current_encoder_out, decoder_out.unsqueeze(1)
|
||||||
|
)
|
||||||
|
|
||||||
# TODO(fangjun): Ccale the blank posterior
|
# TODO(fangjun): Cache the blank posterior
|
||||||
|
|
||||||
log_prob = logits.log_softmax(dim=-1)
|
log_prob = logits.log_softmax(dim=-1)
|
||||||
# log_prob is (1, 1, 1, vocab_size)
|
# log_prob is (1, 1, 1, vocab_size)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user