Add modified beam search.

This commit is contained in:
Fangjun Kuang 2022-01-28 16:20:42 +08:00
parent 50eb78566b
commit c3b3123b27
2 changed files with 100 additions and 5 deletions

View File

@ -312,6 +312,90 @@ def run_joiner(
return log_prob
def modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
) -> List[int]:
"""It limits the maximum number of symbols per frame to 1.
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
beam:
Beam size.
Returns:
Return the decoded result.
"""
assert encoder_out.ndim == 3
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
decoder_input = torch.tensor(
[blank_id] * context_size, device=device
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
T = encoder_out.size(1)
B = HypothesisList()
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
encoder_out_len = torch.tensor([1])
decoder_out_len = torch.tensor([1])
decoder_cache: Dict[str, torch.Tensor] = {}
for t in range(T):
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
# fmt: on
A = B
B = HypothesisList()
joint_cache: Dict[str, torch.Tensor] = {}
for hyp in A:
decoder_out = run_decoder(
ys=hyp.ys, model=model, decoder_cache=decoder_cache
)
key = "_".join(map(str, hyp.ys[-context_size:]))
key += f"-t-{t}"
log_prob = run_joiner(
key=key,
model=model,
encoder_out=current_encoder_out,
decoder_out=decoder_out,
encoder_out_len=encoder_out_len,
decoder_out_len=decoder_out_len,
joint_cache=joint_cache,
)
log_prob = log_prob.cpu().tolist()
for i, v in enumerate(log_prob):
if i == blank_id:
# Use [:] to make a copy
new_ys = hyp.ys[:]
else:
new_ys = hyp.ys + [i]
new_hyp = Hypothesis(ys=new_ys, log_prob=hyp.log_prob + v)
B.add(new_hyp)
B = B.topk(beam)
best_hyp = B.get_most_probable(length_norm=True)
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
return ys
def beam_search(
model: Transducer,
encoder_out: torch.Tensor,

View File

@ -46,7 +46,7 @@ import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import beam_search, greedy_search
from beam_search import beam_search, greedy_search, modified_beam_search
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
@ -104,6 +104,7 @@ def get_parser():
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
""",
)
@ -111,7 +112,8 @@ def get_parser():
"--beam-size",
type=int,
default=4,
help="Used only when --decoding-method is beam_search",
help="""Used only when --decoding-method is
beam_search or modified_beam_search""",
)
parser.add_argument(
@ -125,7 +127,8 @@ def get_parser():
"--max-sym-per-frame",
type=int,
default=3,
help="Maximum number of symbols per frame",
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
return parser
@ -256,6 +259,10 @@ def decode_one_batch(
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
elif params.decoding_method == "modified_beam_search":
hyp = modified_beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
@ -389,11 +396,15 @@ def main():
params = get_params()
params.update(vars(args))
assert params.decoding_method in ("greedy_search", "beam_search")
assert params.decoding_method in (
"greedy_search",
"beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.decoding_method == "beam_search":
if "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"