From bfd331c27ebf457e9f9f567bdf02a253a93ce8ac Mon Sep 17 00:00:00 2001 From: pkufool Date: Fri, 7 Jan 2022 14:42:41 +0800 Subject: [PATCH] Fix beam search --- egs/librispeech/ASR/transducer_stateless/beam_search.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py index f347f552f..ee7c4160a 100644 --- a/egs/librispeech/ASR/transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py @@ -128,7 +128,6 @@ class HypothesisList(object): def data(self): return self._data - # def add(self, ys: List[int], log_prob: float): def add(self, hyp: Hypothesis): """Add a Hypothesis to `self`. @@ -266,7 +265,7 @@ def beam_search( while t < T and sym_per_utt < max_sym_per_utt: # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :] + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) # fmt: on A = B B = HypothesisList() @@ -294,9 +293,11 @@ def beam_search( cached_key += f"-t-{t}" if cached_key not in joint_cache: - logits = model.joiner(current_encoder_out, decoder_out) + logits = model.joiner( + current_encoder_out, decoder_out.unsqueeze(1) + ) - # TODO(fangjun): Ccale the blank posterior + # TODO(fangjun): Cache the blank posterior log_prob = logits.log_softmax(dim=-1) # log_prob is (1, 1, 1, vocab_size)