From 3174bebf073cf5f1b567de432ca1e41d0862e0da Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Wed, 15 Dec 2021 18:50:29 +0800
Subject: [PATCH] Add beam search.

---
 .gitignore                                    |   1 +
 egs/librispeech/ASR/transducer/beam_search.py | 150 +++++++++++++++++-
 egs/librispeech/ASR/transducer/decode.py      |  69 +++++++-
 3 files changed, 212 insertions(+), 8 deletions(-)

diff --git a/.gitignore b/.gitignore
index 31da5ed3e..870d3cea3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,3 +8,4 @@ exp*/
 download
 *.bak
 *-bak
+*bak.py
diff --git a/egs/librispeech/ASR/transducer/beam_search.py b/egs/librispeech/ASR/transducer/beam_search.py
index 8eb19c7af..62ad14257 100644
--- a/egs/librispeech/ASR/transducer/beam_search.py
+++ b/egs/librispeech/ASR/transducer/beam_search.py
@@ -14,7 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple
 
 import torch
 from transducer.model import Transducer
@@ -50,9 +51,10 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
         current_encoder_out = encoder_out[:, t:t+1, :]
         # fmt: on
         logits = model.joiner(current_encoder_out, decoder_out)
+        # logits is (1, 1, 1, vocab_size)
 
         log_prob = logits.log_softmax(dim=-1)
-        # log_prob is (N, 1, 1)
+        # log_prob is (1, 1, 1, vocab_size)
         # TODO: Use logits.argmax()
         y = log_prob.argmax()
         if y != blank_id:
@@ -64,3 +66,147 @@
         t += 1
 
     return hyp
+
+
+@dataclass
+class Hypothesis:
+    ys: List[int]  # the predicted sequence so far
+    log_prob: float  # the log prob of ys
+
+    # Optional decoder state. We assume it is an LSTM for now,
+    # so the state is a tuple (h, c)
+    decoder_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+
+
+def beam_search(
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    beam: int = 5,
+) -> List[int]:
+    """
+    It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
+
+    espnet/nets/beam_search_transducer.py#L247 is used as a reference.
+
+    Args:
+      model:
+        An instance of `Transducer`.
+      encoder_out:
+        Output from the encoder, of shape (N, T, C); only N==1 is supported.
+      beam:
+        Beam size.
+    Returns:
+      Return the decoded result, a list of token IDs.
+    """
+    assert encoder_out.ndim == 3
+
+    # support only batch_size == 1 for now
+    assert encoder_out.size(0) == 1, encoder_out.size(0)
+    blank_id = model.decoder.blank_id
+    sos_id = model.decoder.sos_id
+    device = model.device
+
+    sos = torch.tensor([blank_id], device=device).reshape(1, 1)
+    decoder_out, (h, c) = model.decoder(sos)
+    T = encoder_out.size(1)
+    t = 0
+    B = [Hypothesis(ys=[blank_id], log_prob=0.0, decoder_state=None)]
+    max_u = 20000  # terminate after this number of expansion steps
+    u = 0
+
+    cache: Dict[
+        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
+    ] = {}
+
+    while t < T and u < max_u:
+        # fmt: off
+        current_encoder_out = encoder_out[:, t:t+1, :]
+        # fmt: on
+        A = B
+        B = []
+        # for hyp in A:
+        #     for h in A:
+        #         if h.ys == hyp.ys[:-1]:
+        #             # update the score of hyp
+        #             decoder_input = torch.tensor(
+        #                 [h.ys[-1]], device=device
+        #             ).reshape(1, 1)
+        #             decoder_out, _ = model.decoder(
+        #                 decoder_input, h.decoder_state
+        #             )
+        #             logits = model.joiner(current_encoder_out, decoder_out)
+        #             log_prob = logits.log_softmax(dim=-1)
+        #             log_prob = log_prob.squeeze()
+        #             hyp.log_prob += h.log_prob + log_prob[hyp.ys[-1]].item()
+
+        while u < max_u:
+            y_star = max(A, key=lambda hyp: hyp.log_prob)
+            A.remove(y_star)
+
+            # Note: y_star.ys is a list, which is unhashable, so it
+            # cannot be used directly as a key of a dict
+            cached_key = "_".join(map(str, y_star.ys))
+
+            if cached_key not in cache:
+                decoder_input = torch.tensor(
+                    [y_star.ys[-1]], device=device
+                ).reshape(1, 1)
+
+                decoder_out, decoder_state = model.decoder(
+                    decoder_input,
+                    y_star.decoder_state,
+                )
+                cache[cached_key] = (decoder_out, decoder_state)
+            else:
+                decoder_out, decoder_state = cache[cached_key]
+
+            logits = model.joiner(current_encoder_out, decoder_out)
+            log_prob = logits.log_softmax(dim=-1)
+            # log_prob is (1, 1, 1, vocab_size)
+            log_prob = log_prob.squeeze()
+            # Now log_prob is (vocab_size,)
+
+            # If we choose blank here, add the new hypothesis to B.
+            # Otherwise, add the new hypothesis to A.
+
+            # First, choose blank
+            skip_log_prob = log_prob[blank_id]
+            new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
+
+            # ys[:] returns a copy of ys
+            new_y_star = Hypothesis(
+                ys=y_star.ys[:],
+                log_prob=new_y_star_log_prob,
+                # Caution: use y_star.decoder_state; blank does not advance the decoder
+                decoder_state=y_star.decoder_state,
+            )
+            B.append(new_y_star)
+
+            # Second, choose other labels
+            for i, v in enumerate(log_prob.tolist()):
+                if i in (blank_id, sos_id):
+                    continue
+                new_ys = y_star.ys + [i]
+                new_log_prob = y_star.log_prob + v
+                new_hyp = Hypothesis(
+                    ys=new_ys,
+                    log_prob=new_log_prob,
+                    decoder_state=decoder_state,
+                )
+                A.append(new_hyp)
+            u += 1
+            # Check whether B contains at least "beam" elements that are
+            # more probable than the most probable hypothesis in A
+            A_most_probable = max(A, key=lambda hyp: hyp.log_prob)
+            B = sorted(
+                [hyp for hyp in B if hyp.log_prob > A_most_probable.log_prob],
+                key=lambda hyp: hyp.log_prob,
+                reverse=True,
+            )
+            if len(B) >= beam:
+                B = B[:beam]
+                break
+        t += 1
+    best_hyp = max(B, key=lambda hyp: hyp.log_prob / len(hyp.ys[1:]))
+    ys = best_hyp.ys[1:]  # [1:] to remove the blank
+    return ys
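
Note (illustrative, not part of the patch): the inner while-loop above interleaves two
hypothesis sets, A (still being extended at frame t) and B (frozen at frame t by emitting
blank). The self-contained toy sketch below reproduces that bookkeeping with a fixed table
of per-frame log probs standing in for the decoder/joiner. The names Hyp, toy_beam_search,
the scores, and the division guard at the end are all made up for this sketch:

    import math
    from dataclasses import dataclass
    from typing import List

    @dataclass
    class Hyp:
        ys: List[int]  # ys[0] is the leading blank, as in the patch
        log_prob: float

    def toy_beam_search(frames: List[List[float]], beam: int = 2, blank: int = 0) -> List[int]:
        B = [Hyp(ys=[blank], log_prob=0.0)]
        u, max_u = 0, 1000  # safety bound, mirroring max_u in the patch
        for frame in frames:  # one row of log probs per encoder frame
            A, B = B, []
            while u < max_u:
                u += 1
                y_star = max(A, key=lambda h: h.log_prob)
                A.remove(y_star)
                # Emitting blank freezes y_star for this frame -> goes to B
                B.append(Hyp(y_star.ys[:], y_star.log_prob + frame[blank]))
                # Emitting a label keeps the hypothesis competing -> stays in A
                for i, v in enumerate(frame):
                    if i != blank:
                        A.append(Hyp(y_star.ys + [i], y_star.log_prob + v))
                best_in_A = max(A, key=lambda h: h.log_prob).log_prob
                B = sorted(
                    [h for h in B if h.log_prob > best_in_A],
                    key=lambda h: h.log_prob,
                    reverse=True,
                )
                if len(B) >= beam:
                    B = B[:beam]
                    break
        # Length-normalized selection; max(..., 1) guards the empty hypothesis
        best = max(B, key=lambda h: h.log_prob / max(len(h.ys) - 1, 1))
        return best.ys[1:]

    frames = [
        [math.log(p) for p in dist]
        for dist in ([0.6, 0.3, 0.1], [0.7, 0.2, 0.1], [0.6, 0.1, 0.3])
    ]
    print(toy_beam_search(frames))  # prints [] -- blank dominates every frame here

The real implementation differs in that the distribution at each step depends on the
emitted prefix via the decoder state, which is why the patch caches (decoder_out,
decoder_state) per label sequence.
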
diff --git a/egs/librispeech/ASR/transducer/decode.py b/egs/librispeech/ASR/transducer/decode.py
index 7bb4060b2..2d7fcf41d 100755
--- a/egs/librispeech/ASR/transducer/decode.py
+++ b/egs/librispeech/ASR/transducer/decode.py
@@ -15,6 +15,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+Usage:
+(1) greedy search
+./transducer/decode.py \
+    --epoch 14 \
+    --avg 7 \
+    --exp-dir ./transducer/exp \
+    --max-duration 100 \
+    --decoding-method greedy_search
+
+(2) beam search
+./transducer/decode.py \
+    --epoch 14 \
+    --avg 7 \
+    --exp-dir ./transducer/exp \
+    --max-duration 100 \
+    --decoding-method beam_search \
+    --beam-size 8
+"""
 
 import argparse
 
@@ -27,7 +46,7 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from transducer.beam_search import greedy_search
+from transducer.beam_search import beam_search, greedy_search
 from transducer.conformer import Conformer
 from transducer.decoder import Decoder
 from transducer.joiner import Joiner
@@ -78,6 +97,23 @@ def get_parser():
         help="Path to the BPE model",
     )
 
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=5,
+        help="Used only when --decoding-method is beam_search",
+    )
+
     return parser
 
 
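
Note (illustrative, not part of the patch): the two new flags can be exercised standalone.
The snippet below mirrors only the two new add_argument() calls from get_parser() above and
checks that the defaults preserve the old behavior; it is a sketch, not the recipe's parser:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--decoding-method", type=str, default="greedy_search")
    parser.add_argument("--beam-size", type=int, default=5)

    # Defaults preserve the old behavior (greedy search).
    args = parser.parse_args([])
    assert (args.decoding_method, args.beam_size) == ("greedy_search", 5)

    # The beam search invocation from the usage docstring above.
    args = parser.parse_args(["--decoding-method", "beam_search", "--beam-size", "8"])
    assert (args.decoding_method, args.beam_size) == ("beam_search", 8)
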
@@ -205,11 +241,22 @@ def decode_one_batch(
         # fmt: off
         encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
         # fmt: on
-        hyp = greedy_search(model=model, encoder_out=encoder_out_i)
+        if params.decoding_method == "greedy_search":
+            hyp = greedy_search(model=model, encoder_out=encoder_out_i)
+        elif params.decoding_method == "beam_search":
+            hyp = beam_search(
+                model=model, encoder_out=encoder_out_i, beam=params.beam_size
+            )
+        else:
+            raise ValueError(
+                f"Unsupported decoding method: {params.decoding_method}"
+            )
         hyps.append(sp.decode(hyp).split())
 
-    return {"greedy_search": hyps}
-    # TODO: Implement beam search
+    if params.decoding_method == "greedy_search":
+        return {"greedy_search": hyps}
+    else:
+        return {f"beam_{params.beam_size}": hyps}
 
 
 def decode_dataset(
@@ -243,6 +290,11 @@ def decode_dataset(
     except TypeError:
         num_batches = "?"
 
+    if params.decoding_method == "greedy_search":
+        log_interval = 100
+    else:
+        log_interval = 2
+
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
@@ -265,7 +317,7 @@
 
         num_cuts += len(texts)
 
-        if batch_idx % 100 == 0:
+        if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
             logging.info(
@@ -327,8 +379,13 @@ def main():
 
     params = get_params()
    params.update(vars(args))
-    params.res_dir = params.exp_dir / "greedy_search"
+
+    assert params.decoding_method in ("greedy_search", "beam_search")
+    params.res_dir = params.exp_dir / params.decoding_method
+
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+    if params.decoding_method == "beam_search":
+        params.suffix += f"-beam-{params.beam_size}"
 
     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")
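
Note (illustrative, not part of the patch): decode_one_batch() keys its results dict by
decoding method, and main() mirrors that in the result directory and log-file suffix. A
minimal sketch of the naming scheme; the helpers result_key and log_suffix are hypothetical,
introduced only to make the convention explicit:

    def result_key(method: str, beam_size: int) -> str:
        # Mirrors the dict key returned by decode_one_batch() above
        return "greedy_search" if method == "greedy_search" else f"beam_{beam_size}"

    def log_suffix(method: str, epoch: int, avg: int, beam_size: int) -> str:
        # Mirrors the params.suffix logic in main() above
        suffix = f"epoch-{epoch}-avg-{avg}"
        if method == "beam_search":
            suffix += f"-beam-{beam_size}"
        return suffix

    assert result_key("beam_search", 8) == "beam_8"
    assert log_suffix("beam_search", 14, 7, 8) == "epoch-14-avg-7-beam-8"

With this scheme, results from different methods and beam sizes land in separate
directories and files, so runs can be compared side by side.
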