Add fast beam search decoding

This commit is contained in:
pkufool 2022-03-14 10:26:04 +08:00
parent bb7f6ed6b7
commit 0e998d5f8c
3 changed files with 193 additions and 27 deletions

View File

@ -17,9 +17,86 @@
from dataclasses import dataclass
from typing import Dict, List, Optional
import k2
import torch
from model import Transducer
from icefall.decode import one_best_decoding
from icefall.utils import get_texts
def fast_beam_search(
decoding_graph: k2.Fsa,
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
) -> List[int]:
"""It limits the maximum number of symbols per frame to 1.
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
beam:
Beam value, similar to the beam used in Kaldi..
max_states:
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
Returns:
Return the decoded result.
"""
assert encoder_out.ndim == 3
context_size = model.decoder.context_size
vocab_size = model.decoder.vocab_size
B, T, C = encoder_out.shape
config = k2.RnntDecodingConfig(
vocab_size=vocab_size,
decoder_history_len=context_size,
beam=beam,
max_contexts=max_contexts,
max_states=max_states,
)
indivisual_streams = []
for i in range(B):
indivisual_streams.append(k2.RnntDecodingStream(decoding_graph))
decoding_streams = k2.RnntDecodingStreams(indivisual_streams, config)
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
# current_encoder_out is of shape
# (shape.NumElements(), 1, encoder_out_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1)
)
# fmt: on
logits = model.joiner(
current_encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.detach()
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
best_path = one_best_decoding(lattice)
hyps = get_texts(best_path)
return hyps
def greedy_search(
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int

View File

@ -42,6 +42,17 @@ Usage:
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
./pruned_transducer_stateless/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 8 \
--max-contexts 10 \
--max-states 20
"""
@ -51,11 +62,17 @@ from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import beam_search, greedy_search, modified_beam_search
from beam_search import (
beam_search,
fast_beam_search,
greedy_search,
modified_beam_search,
)
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
@ -110,6 +127,7 @@ def get_parser():
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
@ -121,6 +139,30 @@ def get_parser():
beam_search or modified_beam_search""",
)
parser.add_argument(
"--beam",
type=float,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=5,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=10,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
@ -143,6 +185,7 @@ def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: k2.Fsa,
batch: dict,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
@ -162,6 +205,9 @@ def decode_one_batch(
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.TrivialGraph` or HLG, Used
only when --decoding_method is fast_beam_search.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
@ -184,36 +230,62 @@ def decode_one_batch(
x=feature, x_lens=feature_lens
)
hyps = []
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
elif params.decoding_method == "modified_beam_search":
hyp = modified_beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
elif params.decoding_method == "modified_beam_search":
hyp = modified_beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_{params.beam_size}": hyps}
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
@ -221,6 +293,7 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: k2.Fsa,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@ -233,6 +306,9 @@ def decode_dataset(
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.TrivialGraph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -260,6 +336,7 @@ def decode_dataset(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
batch=batch,
)
@ -340,12 +417,17 @@ def main():
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "beam_search" in params.decoding_method:
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
@ -388,6 +470,11 @@ def main():
model.eval()
model.device = device
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.TrivialGraph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@ -408,6 +495,7 @@ def main():
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(

View File

@ -61,6 +61,7 @@ class Decoder(nn.Module):
assert context_size >= 1, context_size
self.context_size = context_size
self.vocab_size = vocab_size
if context_size > 1:
self.conv = nn.Conv1d(
in_channels=embedding_dim,