diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py
index 3d4818509..1a2aa487e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py
@@ -48,7 +48,7 @@ def greedy_search(
     device = model.device

     decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
+        [blank_id] * context_size, device=device, dtype=torch.int64
     ).reshape(1, context_size)

     decoder_out = model.decoder(decoder_input, need_pad=False)
@@ -103,8 +103,9 @@ class Hypothesis:
     # Newly predicted tokens are appended to `ys`.
     ys: List[int]

-    # The log prob of ys
-    log_prob: float
+    # The log prob of ys.
+    # It contains only one entry.
+    log_prob: torch.Tensor

     @property
     def key(self) -> str:
@@ -113,7 +114,7 @@


 class HypothesisList(object):
-    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None):
+    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
         """
         Args:
           data:
@@ -125,10 +126,10 @@ class HypothesisList(object):
         self._data = data

     @property
-    def data(self):
+    def data(self) -> Dict[str, Hypothesis]:
         return self._data

-    def add(self, hyp: Hypothesis):
+    def add(self, hyp: Hypothesis) -> None:
         """Add a Hypothesis to `self`.

         If `hyp` already exists in `self`, its probability is updated using
@@ -140,8 +141,10 @@
         """
         key = hyp.key
         if key in self:
-            old_hyp = self._data[key]
-            old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob)
+            old_hyp = self._data[key]  # shallow copy
+            torch.logaddexp(
+                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
+            )
         else:
             self._data[key] = hyp

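A note on the torch.logaddexp change in HypothesisList.add above: when a newly proposed hypothesis carries the same token sequence as one already in the list, the two paths are merged by adding their probabilities, i.e. log-add-exp of their log probs. Storing log_prob as a one-element tensor (the Hypothesis change above) is what makes the in-place out= update possible. A tiny self-contained sketch, with made-up values:

    import torch

    old_log_prob = torch.tensor([-2.0])  # hypothesis already in the list
    new_log_prob = torch.tensor([-3.0])  # same ys, reached via another path
    # Merge the two paths: log(exp(-2.0) + exp(-3.0)) ~= -1.69 is written
    # back into old_log_prob in place, so the stored Hypothesis is updated.
    torch.logaddexp(old_log_prob, new_log_prob, out=old_log_prob)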
+ """ + + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + + device = model.device + + T = encoder_out.size(1) + + B = HypothesisList() + B.add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + + for t in range(T): + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # current_encoder_out is of shape (1, 1, encoder_out_dim) + # fmt: on + A = list(B) + B = HypothesisList() + + ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) + # ys_log_probs is of shape (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyp in A], + device=device, + dtype=torch.int64, + ) + # decoder_input is of shape (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + # decoder_output is of shape (num_hyps, 1,1, decoder_output_dim) + + current_encoder_out = current_encoder_out.expand( + decoder_out.size(0), 1, 1, -1 + ) + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) + # logits is of shape (num_hyps, 1, 1, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + # now logits is of shape (num_hyps, vocab_size) + log_probs = logits.log_softmax(dim=-1) + + log_probs.add_(ys_log_probs) + + log_probs = log_probs.reshape(-1) + topk_log_probs, topk_indexes = log_probs.topk(beam) + + # topk_hyp_indexes are indexes into `A` + topk_hyp_indexes = topk_indexes // logits.size(-1) + topk_token_indexes = topk_indexes % logits.size(-1) + + topk_hyp_indexes = topk_hyp_indexes.tolist() + topk_token_indexes = topk_token_indexes.tolist() + + for i in range(len(topk_hyp_indexes)): + hyp = A[topk_hyp_indexes[i]] + new_ys = hyp.ys[:] + new_token = topk_token_indexes[i] + if new_token != blank_id: + new_ys.append(new_token) + new_log_prob = topk_log_probs[i] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B.add(new_hyp) + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + + return ys + + def beam_search( model: Transducer, encoder_out: torch.Tensor, @@ -246,7 +353,9 @@ def beam_search( device = model.device decoder_input = torch.tensor( - [blank_id] * context_size, device=device + [blank_id] * context_size, + device=device, + dtype=torch.int64, ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -283,7 +392,9 @@ def beam_search( if cached_key not in decoder_cache: decoder_input = torch.tensor( - [y_star.ys[-context_size:]], device=device + [y_star.ys[-context_size:]], + device=device, + dtype=torch.int64, ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -297,7 +408,7 @@ def beam_search( current_encoder_out, decoder_out.unsqueeze(1) ) - # TODO(fangjun): Cache the blank posterior + # TODO(fangjun): Scale the blank posterior log_prob = logits.log_softmax(dim=-1) # log_prob is (1, 1, 1, vocab_size) @@ -309,7 +420,7 @@ def beam_search( # First, process the blank symbol skip_log_prob = log_prob[blank_id] - new_y_star_log_prob = y_star.log_prob + skip_log_prob.item() + new_y_star_log_prob = y_star.log_prob + skip_log_prob # ys[:] returns a copy of ys B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
index 9479d57a8..e11189f1c 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
@@ -33,6 +33,15 @@ Usage:
     --max-duration 100 \
     --decoding-method beam_search \
     --beam-size 4
+
+(3) modified beam search
+./pruned_transducer_stateless/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless/exp \
+    --max-duration 100 \
+    --decoding-method modified_beam_search \
+    --beam-size 4
 """


@@ -46,7 +55,7 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from beam_search import beam_search, greedy_search
+from beam_search import beam_search, greedy_search, modified_beam_search
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
@@ -104,6 +113,7 @@
         help="""Possible values are:
           - greedy_search
          - beam_search
+          - modified_beam_search
         """,
     )

@@ -111,7 +121,8 @@
         "--beam-size",
         type=int,
         default=4,
-        help="Used only when --decoding-method is beam_search",
+        help="""Used only when --decoding-method is
+        beam_search or modified_beam_search""",
     )

     parser.add_argument(
@@ -125,7 +136,8 @@
         "--max-sym-per-frame",
         type=int,
         default=3,
-        help="Maximum number of symbols per frame",
+        help="""Maximum number of symbols per frame.
+        Used only when --decoding-method is greedy_search""",
     )

     return parser
@@ -258,6 +270,10 @@
             hyp = beam_search(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
+        elif params.decoding_method == "modified_beam_search":
+            hyp = modified_beam_search(
+                model=model, encoder_out=encoder_out_i, beam=params.beam_size
+            )
         else:
             raise ValueError(
                 f"Unsupported decoding method: {params.decoding_method}"
@@ -391,11 +407,15 @@
     params = get_params()
     params.update(vars(args))

-    assert params.decoding_method in ("greedy_search", "beam_search")
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "modified_beam_search",
+    )
     params.res_dir = params.exp_dir / params.decoding_method

     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    if params.decoding_method == "beam_search":
+    if "beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
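For context, the new modified_beam_search branch in decode_one_batch runs once per utterance, since the function only supports batch size 1. A rough sketch of that loop; model, encoder_out, encoder_out_lens, params, and sp (the SentencePiece processor) stand in for objects decode.py already has in scope:

    from beam_search import modified_beam_search

    hyps = []
    for i in range(encoder_out.size(0)):
        # Keep a batch dimension of 1 and drop this utterance's padding frames.
        encoder_out_i = encoder_out[i : i + 1, : encoder_out_lens[i]]
        hyp = modified_beam_search(
            model=model, encoder_out=encoder_out_i, beam=params.beam_size
        )
        hyps.append(sp.decode(hyp).split())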