Fix beam search

This commit is contained in:
pkufool 2022-01-07 14:42:41 +08:00
parent f4b8b0641a
commit bfd331c27e

View File

@ -128,7 +128,6 @@ class HypothesisList(object):
def data(self):
return self._data
# def add(self, ys: List[int], log_prob: float):
def add(self, hyp: Hypothesis):
"""Add a Hypothesis to `self`.
@ -266,7 +265,7 @@ def beam_search(
while t < T and sym_per_utt < max_sym_per_utt:
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
# fmt: on
A = B
B = HypothesisList()
@ -294,9 +293,11 @@ def beam_search(
cached_key += f"-t-{t}"
if cached_key not in joint_cache:
logits = model.joiner(current_encoder_out, decoder_out)
logits = model.joiner(
current_encoder_out, decoder_out.unsqueeze(1)
)
# TODO(fangjun): Ccale the blank posterior
# TODO(fangjun): Cache the blank posterior
log_prob = logits.log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size)