Add fast beam search decoding

pkufool 2022-03-14 10:26:04 +08:00
parent bb7f6ed6b7
commit 0e998d5f8c
3 changed files with 193 additions and 27 deletions

pruned_transducer_stateless/beam_search.py

@@ -17,9 +17,86 @@
from dataclasses import dataclass
from typing import Dict, List, Optional

import k2
import torch

from model import Transducer

from icefall.decode import one_best_decoding
from icefall.utils import get_texts

def fast_beam_search(
    decoding_graph: k2.Fsa,
    model: Transducer,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    beam: float,
    max_states: int,
    max_contexts: int,
) -> List[List[int]]:
    """It limits the maximum number of symbols per frame to 1.

    Args:
      decoding_graph:
        Decoding graph used for decoding; may be a trivial graph (returned
        by `k2.trivial_graph`) or an HLG.
      model:
        An instance of `Transducer`.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder.
      encoder_out_lens:
        A tensor of shape (N,) containing the number of valid frames in
        `encoder_out` before padding.
      beam:
        Beam value, similar to the beam used in Kaldi.
      max_states:
        Max states per stream per frame.
      max_contexts:
        Max contexts per stream per frame.
    Returns:
      Return the decoded results, one list of token IDs per utterance.
    """
    assert encoder_out.ndim == 3

    context_size = model.decoder.context_size
    vocab_size = model.decoder.vocab_size

    B, T, C = encoder_out.shape

    config = k2.RnntDecodingConfig(
        vocab_size=vocab_size,
        decoder_history_len=context_size,
        beam=beam,
        max_contexts=max_contexts,
        max_states=max_states,
    )
    individual_streams = []
    for i in range(B):
        individual_streams.append(k2.RnntDecodingStream(decoding_graph))
    decoding_streams = k2.RnntDecodingStreams(individual_streams, config)

    for t in range(T):
        # shape is a RaggedShape of shape (B, context)
        # contexts is a Tensor of shape (shape.NumElements(), context_size)
        shape, contexts = decoding_streams.get_contexts()

        # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
        decoder_out = model.decoder(contexts, need_pad=False)

        # current_encoder_out is of shape
        # (shape.NumElements(), 1, encoder_out_dim)
        # fmt: off
        current_encoder_out = torch.index_select(
            encoder_out[:, t:t + 1, :], 0, shape.row_ids(1)
        )
        # fmt: on
        logits = model.joiner(
            current_encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
        )
        logits = logits.squeeze(1).squeeze(1)
        log_probs = logits.log_softmax(dim=-1)
        decoding_streams.advance(log_probs)

    decoding_streams.detach()
    lattice = decoding_streams.format_output(encoder_out_lens.tolist())

    best_path = one_best_decoding(lattice)
    hyps = get_texts(best_path)
    return hyps

def greedy_search(
    model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
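
For orientation, here is a minimal sketch of how `fast_beam_search` above would be driven end to end. It assumes a trained Transducer `model`, a matching SentencePiece processor `sp`, and padded feature tensors `feature`/`feature_lens`; none of this is part of the commit, and the graph construction simply mirrors the decode.py change below.

import k2
import torch

from beam_search import fast_beam_search

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Unconstrained decoding graph over the BPE vocabulary (blank is token 0);
# mirrors the construction added in decode.py below.
vocab_size = sp.get_piece_size()
decoding_graph = k2.trivial_graph(vocab_size - 1, device=device)

# (N, T, C) encoder output and per-utterance valid lengths.
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)

hyp_tokens = fast_beam_search(
    decoding_graph=decoding_graph,
    model=model,
    encoder_out=encoder_out,
    encoder_out_lens=encoder_out_lens,
    beam=8.0,
    max_states=20,
    max_contexts=10,
)
# hyp_tokens is a List[List[int]]; decode each utterance back to text.
texts = [sp.decode(tokens) for tokens in hyp_tokens]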

pruned_transducer_stateless/decode.py

@@ -42,6 +42,17 @@ Usage:
    --max-duration 100 \
    --decoding-method modified_beam_search \
    --beam-size 4

(4) fast beam search
./pruned_transducer_stateless/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./pruned_transducer_stateless/exp \
    --max-duration 1500 \
    --decoding-method fast_beam_search \
    --beam 8 \
    --max-contexts 10 \
    --max-states 20
"""
@@ -51,11 +62,17 @@ from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
    beam_search,
    fast_beam_search,
    greedy_search,
    modified_beam_search,
)
from train import get_params, get_transducer_model

from icefall.checkpoint import average_checkpoints, load_checkpoint
@@ -110,6 +127,7 @@ def get_parser():
          - greedy_search
          - beam_search
          - modified_beam_search
          - fast_beam_search
        """,
    )
@@ -121,6 +139,30 @@ def get_parser():
        beam_search or modified_beam_search""",
    )
    parser.add_argument(
        "--beam",
        type=float,
        default=8,
        help="""Used only when --decoding-method is
        fast_beam_search""",
    )

    parser.add_argument(
        "--max-contexts",
        type=int,
        default=5,
        help="""Used only when --decoding-method is
        fast_beam_search""",
    )

    parser.add_argument(
        "--max-states",
        type=int,
        default=10,
        help="""Used only when --decoding-method is
        fast_beam_search""",
    )
    parser.add_argument(
        "--context-size",
        type=int,
@@ -143,6 +185,7 @@ def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    decoding_graph: Optional[k2.Fsa],
    batch: dict,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
@@ -162,6 +205,9 @@ def decode_one_batch(
        The neural model.
      sp:
        The BPE model.
      decoding_graph:
        The decoding graph. Can be either a trivial graph (returned by
        `k2.trivial_graph`) or an HLG. Used only when --decoding-method
        is fast_beam_search.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
@@ -184,36 +230,62 @@ def decode_one_batch(
        x=feature, x_lens=feature_lens
    )
    hyps = []

    if params.decoding_method == "fast_beam_search":
        hyp_tokens = fast_beam_search(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    else:
        batch_size = encoder_out.size(0)

        for i in range(batch_size):
            # fmt: off
            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
            # fmt: on
            if params.decoding_method == "greedy_search":
                hyp = greedy_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    max_sym_per_frame=params.max_sym_per_frame,
                )
            elif params.decoding_method == "beam_search":
                hyp = beam_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    beam=params.beam_size,
                )
            elif params.decoding_method == "modified_beam_search":
                hyp = modified_beam_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    beam=params.beam_size,
                )
            else:
                raise ValueError(
                    f"Unsupported decoding method: {params.decoding_method}"
                )
            hyps.append(sp.decode(hyp).split())

    if params.decoding_method == "greedy_search":
        return {"greedy_search": hyps}
    elif params.decoding_method == "fast_beam_search":
        return {
            (
                f"beam_{params.beam}_"
                f"max_contexts_{params.max_contexts}_"
                f"max_states_{params.max_states}"
            ): hyps
        }
    else:
        return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
@@ -221,6 +293,7 @@ def decode_dataset(
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    decoding_graph: Optional[k2.Fsa],
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
    """Decode dataset.
@@ -233,6 +306,9 @@ def decode_dataset(
        The neural model.
      sp:
        The BPE model.
      decoding_graph:
        The decoding graph. Can be either a trivial graph (returned by
        `k2.trivial_graph`) or an HLG. Used only when --decoding-method
        is fast_beam_search.
    Returns:
      Return a dict, whose key may be "greedy_search" if greedy search
      is used, or it may be "beam_size_7" if a beam size of 7 is used.
@@ -260,6 +336,7 @@ def decode_dataset(
            params=params,
            model=model,
            sp=sp,
            decoding_graph=decoding_graph,
            batch=batch,
        )
@@ -340,12 +417,17 @@ def main():
    assert params.decoding_method in (
        "greedy_search",
        "beam_search",
        "fast_beam_search",
        "modified_beam_search",
    )
    params.res_dir = params.exp_dir / params.decoding_method

    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if "fast_beam_search" in params.decoding_method:
        params.suffix += f"-beam-{params.beam}"
        params.suffix += f"-max-contexts-{params.max_contexts}"
        params.suffix += f"-max-states-{params.max_states}"
    elif "beam_search" in params.decoding_method:
        params.suffix += f"-beam-{params.beam_size}"
    else:
        params.suffix += f"-context-{params.context_size}"
@@ -388,6 +470,11 @@ def main():
    model.eval()
    model.device = device

    if params.decoding_method == "fast_beam_search":
        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
    else:
        decoding_graph = None
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")
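
`k2.trivial_graph` builds a graph that accepts any token sequence, so the search above is shaped only by the beam and the per-frame limits. The docstrings note that an HLG works too; a hedged sketch of that alternative (the checkpoint path follows common icefall layout and is an assumption, not part of this commit):

# Hypothetical alternative: a precompiled HLG instead of the trivial graph.
# "data/lang_bpe_500/HLG.pt" is an assumed path, not part of this commit.
HLG = k2.Fsa.from_dict(
    torch.load("data/lang_bpe_500/HLG.pt", map_location="cpu")
)
decoding_graph = HLG.to(device)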
@@ -408,6 +495,7 @@ def main():
            params=params,
            model=model,
            sp=sp,
            decoding_graph=decoding_graph,
        )

        save_results(

pruned_transducer_stateless/decoder.py

@@ -61,6 +61,7 @@ class Decoder(nn.Module):
        assert context_size >= 1, context_size
        self.context_size = context_size
        # Stored so that fast_beam_search() can read the vocabulary size
        # via model.decoder.vocab_size.
        self.vocab_size = vocab_size
        if context_size > 1:
            self.conv = nn.Conv1d(
                in_channels=embedding_dim,