diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py b/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py index 5151f6243..1a1d9ea9d 100644 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py @@ -18,7 +18,6 @@ from dataclasses import dataclass from typing import Dict, List, Optional -import numpy as np import torch from model import Transducer @@ -141,7 +140,7 @@ class HypothesisList(object): key = hyp.key if key in self: old_hyp = self._data[key] - old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob) + old_hyp.log_prob = torch.logaddexp(old_hyp.log_prob, hyp.log_prob) else: self._data[key] = hyp @@ -211,6 +210,106 @@ class HypothesisList(object): return ", ".join(s) +def modified_beam_search( + model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, +) -> List[int]: + """It limits the maximum number of symbols per frame to 1. + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + Returns: + Return the decoded result. + """ + + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + unk_id = model.decoder.unk_id + context_size = model.decoder.context_size + + device = model.device + + T = encoder_out.size(1) + + B = HypothesisList() + B.add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + + for t in range(T): + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # current_encoder_out is of shape (1, 1, 1, encoder_out_dim) + # fmt: on + A = list(B) + B = HypothesisList() + + ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) + # ys_log_probs is of shape (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyp in A], + device=device, + dtype=torch.int64, + ) + # decoder_input is of shape (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + # decoder_output is of shape (num_hyps, 1, 1, decoder_output_dim) + + current_encoder_out = current_encoder_out.expand( + decoder_out.size(0), 1, 1, -1 + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) + # logits is of shape (num_hyps, 1, 1, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + # now logits is of shape (num_hyps, vocab_size) + log_probs = logits.log_softmax(dim=-1) + + log_probs.add_(ys_log_probs) + + log_probs = log_probs.reshape(-1) + topk_log_probs, topk_indexes = log_probs.topk(beam) + + # topk_hyp_indexes are indexes into `A` + topk_hyp_indexes = topk_indexes // logits.size(-1) + topk_token_indexes = topk_indexes % logits.size(-1) + + topk_hyp_indexes = topk_hyp_indexes.tolist() + topk_token_indexes = topk_token_indexes.tolist() + + for i in range(len(topk_hyp_indexes)): + hyp = A[topk_hyp_indexes[i]] + new_ys = hyp.ys[:] + new_token = topk_token_indexes[i] + if new_token != blank_id and new_token != unk_id: + new_ys.append(new_token) + new_log_prob = topk_log_probs[i] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B.add(new_hyp) + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + + return ys + + def beam_search( model: Transducer, encoder_out: torch.Tensor, diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py index 0814b3f1f..664b1d36a 100755 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py @@ -47,7 +47,7 @@ import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import TedLiumAsrDataModule -from beam_search import beam_search, greedy_search +from beam_search import beam_search, greedy_search, modified_beam_search from conformer import Conformer from decoder import Decoder from joiner import Joiner @@ -105,6 +105,7 @@ def get_parser(): help="""Possible values are: - greedy_search - beam_search + - modified_beam_search """, ) @@ -262,6 +263,10 @@ def decode_one_batch( hyp = beam_search( model=model, encoder_out=encoder_out_i, beam=params.beam_size ) + elif params.decoding_method == "modified_beam_search": + hyp = modified_beam_search( + model=model, encoder_out=encoder_out_i, beam=params.beam_size + ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" @@ -398,6 +403,7 @@ def main(): assert params.decoding_method in ( "greedy_search", "beam_search", + "modified_beam_search", ) params.res_dir = params.exp_dir / params.decoding_method