mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Fix beam search
This commit is contained in:
parent
f4b8b0641a
commit
bfd331c27e
@ -128,7 +128,6 @@ class HypothesisList(object):
|
||||
def data(self):
|
||||
return self._data
|
||||
|
||||
# def add(self, ys: List[int], log_prob: float):
|
||||
def add(self, hyp: Hypothesis):
|
||||
"""Add a Hypothesis to `self`.
|
||||
|
||||
@ -266,7 +265,7 @@ def beam_search(
|
||||
|
||||
while t < T and sym_per_utt < max_sym_per_utt:
|
||||
# fmt: off
|
||||
current_encoder_out = encoder_out[:, t:t+1, :]
|
||||
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
|
||||
# fmt: on
|
||||
A = B
|
||||
B = HypothesisList()
|
||||
@ -294,9 +293,11 @@ def beam_search(
|
||||
|
||||
cached_key += f"-t-{t}"
|
||||
if cached_key not in joint_cache:
|
||||
logits = model.joiner(current_encoder_out, decoder_out)
|
||||
logits = model.joiner(
|
||||
current_encoder_out, decoder_out.unsqueeze(1)
|
||||
)
|
||||
|
||||
# TODO(fangjun): Ccale the blank posterior
|
||||
# TODO(fangjun): Cache the blank posterior
|
||||
|
||||
log_prob = logits.log_softmax(dim=-1)
|
||||
# log_prob is (1, 1, 1, vocab_size)
|
||||
|
Loading…
x
Reference in New Issue
Block a user