From fbc1bc3a6b90d1bd6d6238b041fbf73649471e00 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Wed, 22 Dec 2021 15:15:58 +0800
Subject: [PATCH] Implement beam search.

---
 egs/librispeech/ASR/README.md                 |  14 +
 egs/librispeech/ASR/RESULTS.md                |   5 +-
 .../ASR/transducer_stateless/beam_search.py   | 284 ++++++++++++------
 3 files changed, 212 insertions(+), 91 deletions(-)

diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md
index ae0c2684d..76113eff7 100644
--- a/egs/librispeech/ASR/README.md
+++ b/egs/librispeech/ASR/README.md
@@ -1,3 +1,17 @@
+# Introduction
+
 Please refer to the icefall documentation for how to run models in this recipe.
+
+# Transducers
+
+There are various folders whose names contain `transducer` in this directory.
+The following table lists the differences among them.
+
+|                        | Encoder   | Decoder            |
+|------------------------|-----------|--------------------|
+| `transducer`           | Conformer | LSTM               |
+| `transducer_stateless` | Conformer | Conv1d + Embedding |
+
+
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index 1f5edb571..19f5b18a7 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -3,8 +3,9 @@
 ### LibriSpeech BPE training results (RNN-T)
 
 #### 2021-12-17
+Using commit `cb04c8a7509425ab45fae888b0ca71bbbd23f0de`.
 
-RNN-T + Conformer encoder
+RNN-T + Conformer encoder.
 
 The best WER is
 
@@ -12,7 +13,7 @@ The best WER is
 |     | test-clean | test-other |
 |-----|------------|------------|
 | WER | 3.16       | 7.71       |
 
-using `--epoch 26 --avg 12` during decoding with greedy search.
+using `--epoch 26 --avg 12` with **greedy search**.
 
 The training command to reproduce the above WER is:

diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py
index 056e7c372..34b0e9b53 100644
--- a/egs/librispeech/ASR/transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py
@@ -15,8 +15,9 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional
 
+import numpy as np
 import torch
 from model import Transducer
 
@@ -35,25 +36,35 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     # support only batch_size == 1 for now
     assert encoder_out.size(0) == 1, encoder_out.size(0)
+
     blank_id = model.decoder.blank_id
     context_size = model.decoder.context_size
     device = model.device
 
-    sos = torch.tensor([blank_id] * context_size, device=device).reshape(
-        1, context_size
-    )
-    decoder_out = model.decoder(sos, need_pad=False)
+    decoder_input = torch.tensor(
+        [blank_id] * context_size, device=device
+    ).reshape(1, context_size)
+
+    decoder_out = model.decoder(decoder_input, need_pad=False)
+
     T = encoder_out.size(1)
     t = 0
     hyp = [blank_id] * context_size
-    sym_per_frame = 0
-    sym_per_utt = 0
-
+    # Maximum symbols per utterance.
     max_sym_per_utt = 1000
+
+    # If at frame t, it decodes more than this number of symbols,
+    # it will move to the next frame t+1.
     max_sym_per_frame = 3
 
+    # symbols per frame
+    sym_per_frame = 0
+
+    # symbols per utterance decoded so far
+    sym_per_utt = 0
+
     while t < T and sym_per_utt < max_sym_per_utt:
         # fmt: off
         current_encoder_out = encoder_out[:, t:t+1, :]
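Not part of the patch, but for orientation: a minimal sketch of how `greedy_search` is typically driven. The dummy 80-dim filterbank input and the `model.encoder(features, feature_lens)` call are illustrative assumptions; the exact encoder signature may differ in the recipe.

```python
import torch

# `model` is assumed to be a trained Transducer from this recipe,
# e.g. restored from a checkpoint by the decoding script.
model.eval()

# A single utterance, since greedy_search asserts batch_size == 1.
# features has shape (N=1, T, C), e.g. 80-dim filterbanks.
features = torch.randn(1, 300, 80, device=model.device)
feature_lens = torch.tensor([300], device=model.device)

with torch.no_grad():
    # Run the encoder once; its output is consumed frame by frame
    # inside greedy_search.
    encoder_out, _ = model.encoder(features, feature_lens)
    hyp = greedy_search(model=model, encoder_out=encoder_out)

print(hyp)  # predicted token IDs
```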
@@ -83,18 +94,125 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
 
 @dataclass
 class Hypothesis:
-    ys: List[int]  # the predicted sequences so far
-    log_prob: float  # The log prob of ys
+    # The predicted tokens so far.
+    # Newly predicted tokens are appended to `ys`.
+    ys: List[int]
+
-    # Optional decoder state. We assume it is LSTM for now,
-    # so the state is a tuple (h, c)
-    decoder_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+    # The log prob of ys
+    log_prob: float
+
+    @property
+    def key(self) -> str:
+        """Return a string representation of self.ys"""
+        return "_".join(map(str, self.ys))
+
+
+class HypothesisList(object):
+    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None):
+        """
+        Args:
+          data:
+            A dict of Hypotheses. Its key is its `value.key`.
+        """
+        # Note: a mutable default argument (`data={}`) would be shared
+        # across all instances, so we use None as the default instead.
+        if data is None:
+            self._data = {}
+        else:
+            self._data = data
+
+    @property
+    def data(self):
+        return self._data
+
+    def add(self, hyp: Hypothesis):
+        """Add a Hypothesis to `self`.
+
+        If `hyp` already exists in `self`, its probability is updated using
+        `log-sum-exp` with the existing one.
+
+        Args:
+          hyp:
+            The hypothesis to be added.
+        """
+        key = hyp.key
+        if key in self:
+            old_hyp = self._data[key]
+            old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob)
+        else:
+            self._data[key] = hyp
+
+    def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
+        """Get the most probable hypothesis, i.e., the one with
+        the largest `log_prob`.
+
+        Args:
+          length_norm:
+            If True, the `log_prob` of a hypothesis is normalized by the
+            number of tokens in it.
+        """
+        if length_norm:
+            return max(
+                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
+            )
+        else:
+            return max(self._data.values(), key=lambda hyp: hyp.log_prob)
+
+    def remove(self, hyp: Hypothesis) -> None:
+        """Remove a given hypothesis.
+
+        Args:
+          hyp:
+            The hypothesis to be removed from `self`.
+            Note: It must be contained in `self`. Otherwise,
+            an exception is raised.
+        """
+        key = hyp.key
+        assert key in self, f"{key} does not exist"
+        del self._data[key]
+
+    def filter(self, threshold: float) -> "HypothesisList":
+        """Remove all Hypotheses whose log_prob is less than threshold.
+
+        Caution:
+          `self` is not modified. Instead, a new HypothesisList is returned.
+
+        Returns:
+          Return a new HypothesisList containing all hypotheses from `self`
+          whose `log_prob` is greater than the given `threshold`.
+        """
+        ans = HypothesisList()
+        for hyp in self._data.values():
+            if hyp.log_prob > threshold:
+                ans.add(hyp)  # shallow copy
+        return ans
+
+    def topk(self, k: int) -> "HypothesisList":
+        """Return the top-k hypotheses."""
+        hyps = list(self._data.items())
+
+        hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
+
+        ans = HypothesisList(dict(hyps))
+        return ans
+
+    def __contains__(self, key: str) -> bool:
+        return key in self._data
+
+    def __iter__(self):
+        return iter(self._data.values())
+
+    def __len__(self) -> int:
+        return len(self._data)
+
+    def __str__(self) -> str:
+        s = []
+        for hyp in self:
+            s.append(hyp.key)
+        return ", ".join(s)
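Again not part of the patch: a tiny self-contained sketch (with made-up numbers) of the merging behavior. Two hypotheses that reach the same token sequence are combined by `add` via `np.logaddexp` instead of being stored twice.

```python
import numpy as np

hyps = HypothesisList()
hyps.add(Hypothesis(ys=[0, 0, 3], log_prob=np.log(0.30)))
hyps.add(Hypothesis(ys=[0, 0, 5], log_prob=np.log(0.20)))

# The key "0_0_5" already exists, so log(0.20) and log(0.25)
# are merged into log(0.45) rather than creating a duplicate.
hyps.add(Hypothesis(ys=[0, 0, 5], log_prob=np.log(0.25)))

best = hyps.get_most_probable()
assert best.key == "0_0_5"
assert np.isclose(best.log_prob, np.log(0.45))
```

This is why `HypothesisList` is keyed by `Hypothesis.key`: identical label sequences must collapse into a single entry so that the beam is not wasted on duplicates.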
+ """ + ans = HypothesisList() + for key, hyp in self._data.items(): + if hyp.log_prob > threshold: + ans.add(hyp) # shallow copy + return ans + + def topk(self, k: int) -> "HypothesisList": + """Return the top-k hypothesis.""" + hyps = list(self._data.items()) + + hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] + + ans = HypothesisList(dict(hyps)) + return ans + + def __contains__(self, key: str): + return key in self._data + + def __iter__(self): + return iter(self._data.values()) + + def __len__(self) -> int: + return len(self._data) + + def __str__(self) -> str: + s = [] + for key in self: + s.append(key) + return ", ".join(s) def beam_search( model: Transducer, encoder_out: torch.Tensor, - beam: int = 5, + beam: int = 4, ) -> List[int]: """ It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf @@ -116,110 +234,98 @@ def beam_search( # support only batch_size == 1 for now assert encoder_out.size(0) == 1, encoder_out.size(0) blank_id = model.decoder.blank_id - sos_id = model.decoder.sos_id + context_size = model.decoder.context_size + device = model.device - sos = torch.tensor([blank_id], device=device).reshape(1, 1) - decoder_out, (h, c) = model.decoder(sos) + decoder_input = torch.tensor( + [blank_id] * context_size, device=device + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + T = encoder_out.size(1) t = 0 - B = [Hypothesis(ys=[blank_id], log_prob=0.0, decoder_state=None)] - max_u = 20000 # terminate after this number of steps - u = 0 - cache: Dict[ - str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] - ] = {} + B = HypothesisList() + B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) - while t < T and u < max_u: + max_sym_per_utt = 20000 + + sym_per_utt = 0 + + decoder_cache: Dict[str, torch.Tensor] = {} + + while t < T and sym_per_utt < max_sym_per_utt: # fmt: off current_encoder_out = encoder_out[:, t:t+1, :] # fmt: on A = B - B = [] - # for hyp in A: - # for h in A: - # if h.ys == hyp.ys[:-1]: - # # update the score of hyp - # decoder_input = torch.tensor( - # [h.ys[-1]], device=device - # ).reshape(1, 1) - # decoder_out, _ = model.decoder( - # decoder_input, h.decoder_state - # ) - # logits = model.joiner(current_encoder_out, decoder_out) - # log_prob = logits.log_softmax(dim=-1) - # log_prob = log_prob.squeeze() - # hyp.log_prob += h.log_prob + log_prob[hyp.ys[-1]].item() + B = HypothesisList() - while u < max_u: - y_star = max(A, key=lambda hyp: hyp.log_prob) + joint_cache: Dict[str, torch.Tensor] = {} + + # TODO(fangjun): Implement prefix search to update the `log_prob` + # of hypotheses in A + + while True: + y_star = A.get_most_probable() A.remove(y_star) - # Note: y_star.ys is unhashable, i.e., cannot be used - # as a key into a dict - cached_key = "_".join(map(str, y_star.ys)) + cached_key = y_star.key - if cached_key not in cache: + if cached_key not in decoder_cache: decoder_input = torch.tensor( - [y_star.ys[-1]], device=device - ).reshape(1, 1) + [y_star.ys[-context_size:]], device=device + ).reshape(1, context_size) - decoder_out, decoder_state = model.decoder( - decoder_input, - y_star.decoder_state, - ) - cache[cached_key] = (decoder_out, decoder_state) + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_cache[cached_key] = decoder_out else: - decoder_out, decoder_state = cache[cached_key] + decoder_out = decoder_cache[cached_key] - logits = model.joiner(current_encoder_out, decoder_out) - log_prob = logits.log_softmax(dim=-1) - # log_prob is (1, 1, 1, 
-            log_prob = log_prob.squeeze()
-            # Now log_prob is (vocab_size,)
+            cached_key += f"-t-{t}"
+            if cached_key not in joint_cache:
+                logits = model.joiner(current_encoder_out, decoder_out)
 
+                # TODO(fangjun): Scale the blank posterior
 
-            # If we choose blank here, add the new hypothesis to B.
-            # Otherwise, add the new hypothesis to A
-
-            # First, choose blank
+                log_prob = logits.log_softmax(dim=-1)
+                # log_prob is (1, 1, 1, vocab_size)
+                log_prob = log_prob.squeeze()
+                # Now log_prob is (vocab_size,)
+                joint_cache[cached_key] = log_prob
+            else:
+                log_prob = joint_cache[cached_key]
 
+            # First, process the blank symbol
             skip_log_prob = log_prob[blank_id]
             new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
 
             # ys[:] returns a copy of ys
-            new_y_star = Hypothesis(
-                ys=y_star.ys[:],
-                log_prob=new_y_star_log_prob,
-                # Caution: Use y_star.decoder_state here
-                decoder_state=y_star.decoder_state,
-            )
-            B.append(new_y_star)
+            B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
 
-            # Second, choose other labels
-            for i, v in enumerate(log_prob.tolist()):
-                if i in (blank_id, sos_id):
+            # Second, process other non-blank labels
+            values, indices = log_prob.topk(beam + 1)
+            for i, v in zip(indices.tolist(), values.tolist()):
+                if i == blank_id:
                     continue
                 new_ys = y_star.ys + [i]
                 new_log_prob = y_star.log_prob + v
-                new_hyp = Hypothesis(
-                    ys=new_ys,
-                    log_prob=new_log_prob,
-                    decoder_state=decoder_state,
-                )
-                A.append(new_hyp)
-                u += 1
-            # check whether B contains more than "beam" elements more probable
+                A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
+
+            # Check whether B contains more than "beam" elements more probable
             # than the most probable in A
-            A_most_probable = max(A, key=lambda hyp: hyp.log_prob)
-            B = sorted(
-                [hyp for hyp in B if hyp.log_prob > A_most_probable.log_prob],
-                key=lambda hyp: hyp.log_prob,
-                reverse=True,
-            )
-            if len(B) >= beam:
-                B = B[:beam]
+            A_most_probable = A.get_most_probable()
+
+            kept_B = B.filter(A_most_probable.log_prob)
+
+            if len(kept_B) >= beam:
+                B = kept_B.topk(beam)
                 break
+
         t += 1
 
-    best_hyp = max(B, key=lambda hyp: hyp.log_prob / len(hyp.ys[1:]))
-    ys = best_hyp.ys[1:]  # [1:] to remove the blank
+    best_hyp = B.get_most_probable(length_norm=True)
+    ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks
     return ys
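To exercise the new search end to end, something along these lines should work, continuing the hypothetical driver sketched after the greedy-search hunk above. The BPE model path and the `sentencepiece` calls are assumptions for illustration, not part of this patch:

```python
import sentencepiece as spm
import torch

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe/bpe.model")  # hypothetical path

with torch.no_grad():
    encoder_out, _ = model.encoder(features, feature_lens)
    hyp_greedy = greedy_search(model=model, encoder_out=encoder_out)
    hyp_beam = beam_search(model=model, encoder_out=encoder_out, beam=4)

# Both searches support only batch_size == 1, so a test set is
# decoded one utterance at a time.
print("greedy search:", sp.decode(hyp_greedy))
print("beam search  :", sp.decode(hyp_beam))
```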