diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py
index 989caa802..e3bdb1457 100644
--- a/egs/librispeech/ASR/transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py
@@ -118,7 +118,7 @@ class Hypothesis:
 
 
 class HypothesisList(object):
-    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None):
+    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
         """
         Args:
           data:
@@ -130,11 +130,10 @@ class HypothesisList(object):
             self._data = data
 
     @property
-    def data(self):
+    def data(self) -> Dict[str, Hypothesis]:
         return self._data
 
-    # def add(self, ys: List[int], log_prob: float):
-    def add(self, hyp: Hypothesis):
+    def add(self, hyp: Hypothesis) -> None:
         """Add a Hypothesis to `self`.
 
         If `hyp` already exists in `self`, its probability is updated using
@@ -159,7 +158,8 @@ class HypothesisList(object):
           length_norm:
             If True, the `log_prob` of a hypothesis is normalized by the
             number of tokens in it.
-
+        Returns:
+          Return the hypothesis that has the largest `log_prob`.
         """
         if length_norm:
             return max(
@@ -171,6 +171,9 @@ class HypothesisList(object):
     def remove(self, hyp: Hypothesis) -> None:
         """Remove a given hypothesis.
 
+        Caution:
+          `self` is modified **in-place**.
+
         Args:
           hyp:
             The hypothesis to be removed from `self`.
@@ -189,10 +192,10 @@ class HypothesisList(object):
 
         Returns:
           Return a new HypothesisList containing all hypotheses from `self`
-          that have `log_prob` being greater than the given `threshold`.
+          whose `log_prob` is greater than the given `threshold`.
         """
         ans = HypothesisList()
-        for key, hyp in self._data.items():
+        for _, hyp in self._data.items():
             if hyp.log_prob > threshold:
                 ans.add(hyp)  # shallow copy
         return ans
@@ -222,6 +225,171 @@ class HypothesisList(object):
         return ", ".join(s)
 
 
+def run_decoder(
+    ys: List[int],
+    model: Transducer,
+    decoder_cache: Dict[str, torch.Tensor],
+) -> torch.Tensor:
+    """Run the neural decoder model for a given hypothesis.
+
+    Args:
+      ys:
+        The current hypothesis.
+      model:
+        The transducer model.
+      decoder_cache:
+        Cache to save computations.
+    Returns:
+      Return a tensor of shape (1, 1, decoder_out_dim) containing the
+      output of `model.decoder`.
+    """
+    context_size = model.decoder.context_size
+    key = "_".join(map(str, ys[-context_size:]))
+    if key in decoder_cache:
+        return decoder_cache[key]
+
+    device = model.device
+
+    decoder_input = torch.tensor([ys[-context_size:]], device=device).reshape(
+        1, context_size
+    )
+
+    decoder_out = model.decoder(decoder_input, need_pad=False)
+    decoder_cache[key] = decoder_out
+
+    return decoder_out
+
+
+def run_joiner(
+    key: str,
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    decoder_out: torch.Tensor,
+    encoder_out_len: torch.Tensor,
+    decoder_out_len: torch.Tensor,
+    joint_cache: Dict[str, torch.Tensor],
+) -> torch.Tensor:
+    """Run the joint network given outputs from the encoder and decoder.
+
+    Args:
+      key:
+        A key into the `joint_cache`.
+      model:
+        The transducer model.
+      encoder_out:
+        A tensor of shape (1, 1, encoder_out_dim).
+      decoder_out:
+        A tensor of shape (1, 1, decoder_out_dim).
+      encoder_out_len:
+        A tensor with value [1].
+      decoder_out_len:
+        A tensor with value [1].
+      joint_cache:
+        A dict to save computations.
+    Returns:
+      Return a tensor from the output of log-softmax.
+      Its shape is (vocab_size,).
+ """ + if key in joint_cache: + return joint_cache[key] + + logits = model.joiner( + encoder_out, + decoder_out, + encoder_out_len, + decoder_out_len, + ) + + # TODO(fangjun): Scale the blank posterior + log_prob = logits.log_softmax(dim=-1) + # log_prob is (1, 1, 1, vocab_size) + + log_prob = log_prob.squeeze() + # Now log_prob is (vocab_size,) + + joint_cache[key] = log_prob + + return log_prob + + +def start_with(a: List[int], b: List[int]) -> bool: + """Check whether a is started with b, i.e., whether a[len(b)] == b""" + a_len = len(a) + b_len = len(b) + if b_len > a_len: + return False + + for i in range(b_len): + if a[i] != b[i]: + return False + return True + + +# The implementation uses +# espnet/nets/beam_search_transducer.py#L168 +# as a reference +def prefix_search( + hyp_list: HypothesisList, + model: Transducer, + encoder_out: torch.Tensor, + decoder_cache: Dict[str, torch.Tensor], + joint_cache: Dict[str, torch.Tensor], + t: int, +): + hyps = list(hyp_list) + + # sort hyps by number of tokens in descending order + hyps = sorted(hyps, key=lambda h: len(h.ys), reverse=True) + + prefix_alpha = 1 + + device = model.device + context_size = model.decoder.context_size + + encoder_out_len = torch.tensor([1]) + decoder_out_len = torch.tensor([1]) + + for i, cur_hyp in enumerate(hyps[:-1]): + cur_hyp_len = len(cur_hyp.ys) + for next_hyp in hyps[i + 1 :]: # noqa + if not start_with(cur_hyp.ys, next_hyp.ys): + continue + + next_hyp_len = len(next_hyp.ys) + + # at this point, next_hyp.ys is a prefix of cur_hyp.ys + len_diff = cur_hyp_len - next_hyp_len + if len_diff > prefix_alpha: + continue + offset = next_hyp_len + + total_log_prob = next_hyp.log_prob + for i in range(len_diff): + pos = offset + i + ys = cur_hyp.ys[:pos] + + decoder_out = run_decoder( + ys=ys, model=model, decoder_cache=decoder_cache + ) + + key = "_".join(map(str, ys[-context_size:])) + key += f"-t-{t}" + log_prob = run_joiner( + key=key, + model=model, + encoder_out=encoder_out, + decoder_out=decoder_out, + encoder_out_len=encoder_out_len, + decoder_out_len=decoder_out_len, + joint_cache=joint_cache, + ) + total_log_prob += log_prob[cur_hyp.ys[pos]].item() + cur_hyp.log_prob = np.logaddexp(total_log_prob, cur_hyp.log_prob) + + ans = {hyp.key: hyp for hyp in hyps} + return HypothesisList(ans) + + def beam_search( model: Transducer, encoder_out: torch.Tensor, @@ -281,43 +449,34 @@ def beam_search( joint_cache: Dict[str, torch.Tensor] = {} - # TODO(fangjun): Implement prefix search to update the `log_prob` - # of hypotheses in A + A = prefix_search( + hyp_list=A, + model=model, + encoder_out=current_encoder_out, + decoder_cache=decoder_cache, + joint_cache=joint_cache, + t=t, + ) while True: y_star = A.get_most_probable() A.remove(y_star) - cached_key = y_star.key + decoder_out = run_decoder( + ys=y_star.ys, model=model, decoder_cache=decoder_cache + ) - if cached_key not in decoder_cache: - decoder_input = torch.tensor( - [y_star.ys[-context_size:]], device=device - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_cache[cached_key] = decoder_out - else: - decoder_out = decoder_cache[cached_key] - - cached_key += f"-t-{t}" - if cached_key not in joint_cache: - logits = model.joiner( - current_encoder_out, - decoder_out, - encoder_out_len, - decoder_out_len, - ) - - # TODO(fangjun): Ccale the blank posterior - - log_prob = logits.log_softmax(dim=-1) - # log_prob is (1, 1, 1, vocab_size) - log_prob = log_prob.squeeze() - # Now log_prob is (vocab_size,) - 
joint_cache[cached_key] = log_prob - else: - log_prob = joint_cache[cached_key] + key = "_".join(map(str, y_star.ys[-context_size:])) + key += f"-t-{t}" + log_prob = run_joiner( + key=key, + model=model, + encoder_out=current_encoder_out, + decoder_out=decoder_out, + encoder_out_len=encoder_out_len, + decoder_out_len=decoder_out_len, + joint_cache=joint_cache, + ) # First, process the blank symbol skip_log_prob = log_prob[blank_id]
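
Note (not part of the patch): the core of `prefix_search` is the log-space merge done with `np.logaddexp`. When one hypothesis is a prefix of another, the probability of reaching the longer hypothesis by first taking the prefix and then emitting the missing tokens at frame t is folded into the longer hypothesis's `log_prob`. Below is a minimal standalone sketch of just that update; the numbers and variable names are made up for illustration only.

import numpy as np

# Toy log-probabilities (illustrative values only).
prefix_log_prob = -1.2  # log_prob of the shorter hypothesis (the prefix)
longer_log_prob = -2.5  # current log_prob of the longer hypothesis

# Log-probs of emitting the tokens that extend the prefix into the longer
# hypothesis, one term per missing token (in the patch these come from
# run_joiner after log-softmax).
extension_log_probs = [-0.7]

total_log_prob = prefix_log_prob + sum(extension_log_probs)

# Same update as in prefix_search: merge the two paths in log space.
updated = np.logaddexp(total_log_prob, longer_log_prob)
print(updated)  # log(exp(-1.9) + exp(-2.5)) is approximately -1.46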