Add modified-beam-search and fast-beam-search

pkufool 2022-03-21 19:31:51 +08:00
parent 7a3e88d2d3
commit 7896baea14
3 changed files with 332 additions and 44 deletions

View File

@@ -17,10 +17,91 @@
from dataclasses import dataclass
from typing import Dict, List, Optional
import numpy as np
import k2
import torch
from model import Transducer
from icefall.decode import one_best_decoding
from icefall.utils import get_texts
def fast_beam_search(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding; it may be a trivial graph or an HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
Returns:
Return the decoded result.
"""
assert encoder_out.ndim == 3
context_size = model.decoder.context_size
vocab_size = model.decoder.vocab_size
B, T, C = encoder_out.shape
config = k2.RnntDecodingConfig(
vocab_size=vocab_size,
decoder_history_len=context_size,
beam=beam,
max_contexts=max_contexts,
max_states=max_states,
)
individual_streams = []
for i in range(B):
individual_streams.append(k2.RnntDecodingStream(decoding_graph))
decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
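# Decode all streams in lockstep: at every frame we ask k2 for the set of
# active decoder contexts, score them with the decoder/joiner, and then
# advance the streams with the resulting log-probs.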
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
# current_encoder_out is of shape
# (shape.NumElements(), 1, encoder_out_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1)
)
# fmt: on
logits = model.joiner(
current_encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
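# format_output turns the finished streams into a lattice (an FsaVec with
# one Fsa per utterance); the best path through it is the decoding result.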
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
best_path = one_best_decoding(lattice)
hyps = get_texts(best_path)
return hyps
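A minimal usage sketch (hypothetical tensors and helper names; it assumes a
trained `model` on `device` whose encoder follows the same interface as in
decode.py below):

decoding_graph = k2.trivial_graph(model.decoder.vocab_size - 1, device=device)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
hyps = fast_beam_search(
    model=model,
    decoding_graph=decoding_graph,
    encoder_out=encoder_out,
    encoder_out_lens=encoder_out_lens,
    beam=4.0,
    max_states=8,
    max_contexts=4,
)
# hyps[i] is the token-id list for the i-th utterance in the batch.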
def greedy_search(
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
@@ -48,7 +129,7 @@ def greedy_search(
device = model.device
decoder_input = torch.tensor(
[blank_id] * context_size, device=device
[blank_id] * context_size, device=device, dtype=torch.int64
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
@@ -103,8 +184,9 @@ class Hypothesis:
# Newly predicted tokens are appended to `ys`.
ys: List[int]
# The log prob of ys
log_prob: float
# The log prob of ys.
# It contains only one entry.
log_prob: torch.Tensor
@property
def key(self) -> str:
@@ -113,7 +195,7 @@ class Hypothesis:
class HypothesisList(object):
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None):
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
"""
Args:
data:
@@ -125,10 +207,10 @@ class HypothesisList(object):
self._data = data
@property
def data(self):
def data(self) -> Dict[str, Hypothesis]:
return self._data
def add(self, hyp: Hypothesis):
def add(self, hyp: Hypothesis) -> None:
"""Add a Hypothesis to `self`.
If `hyp` already exists in `self`, its probability is updated using
@@ -140,8 +222,10 @@ class HypothesisList(object):
"""
key = hyp.key
if key in self:
old_hyp = self._data[key]
old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob)
old_hyp = self._data[key] # shallow copy
torch.logaddexp(
old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
)
else:
self._data[key] = hyp
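# Note: merging duplicate hypotheses adds their probabilities in log space;
# e.g. torch.logaddexp(torch.tensor(-1.0), torch.tensor(-1.0)) gives
# tensor(-0.3069), i.e. log(e**-1 + e**-1).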
@@ -153,7 +237,8 @@ class HypothesisList(object):
length_norm:
If True, the `log_prob` of a hypothesis is normalized by the
number of tokens in it.
Returns:
Return the hypothesis that has the largest `log_prob`.
"""
if length_norm:
return max(
@@ -165,6 +250,9 @@ class HypothesisList(object):
def remove(self, hyp: Hypothesis) -> None:
"""Remove a given hypothesis.
Caution:
`self` is modified **in-place**.
Args:
hyp:
The hypothesis to be removed from `self`.
@@ -175,7 +263,7 @@ class HypothesisList(object):
assert key in self, f"{key} does not exist"
del self._data[key]
def filter(self, threshold: float) -> "HypothesisList":
def filter(self, threshold: torch.Tensor) -> "HypothesisList":
"""Remove all Hypotheses whose log_prob is less than threshold.
Caution:
@@ -183,10 +271,10 @@ class HypothesisList(object):
Returns:
Return a new HypothesisList containing all hypotheses from `self`
that have `log_prob` being greater than the given `threshold`.
with `log_prob` being greater than the given `threshold`.
"""
ans = HypothesisList()
for key, hyp in self._data.items():
for _, hyp in self._data.items():
if hyp.log_prob > threshold:
ans.add(hyp) # shallow copy
return ans
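# E.g. `hyps.filter(torch.tensor(-5.0))` keeps only the hypotheses whose
# log_prob exceeds -5.0; the surviving Hypothesis objects are shared with
# `self`, not copied.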
@@ -216,6 +304,106 @@ class HypothesisList(object):
return ", ".join(s)
def modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
) -> List[int]:
"""It limits the maximum number of symbols per frame to 1.
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Only N==1 is supported for now.
beam:
Beam size.
Returns:
Return the decoded result.
"""
assert encoder_out.ndim == 3
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
T = encoder_out.size(1)
B = HypothesisList()
B.add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
)
for t in range(T):
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
# current_encoder_out is of shape (1, 1, 1, encoder_out_dim)
# fmt: on
A = list(B)
B = HypothesisList()
ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A])
# ys_log_probs is of shape (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyp in A],
device=device,
dtype=torch.int64,
)
# decoder_input is of shape (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
# decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)
current_encoder_out = current_encoder_out.expand(
decoder_out.size(0), 1, 1, -1
) # (num_hyps, 1, 1, encoder_out_dim)
logits = model.joiner(
current_encoder_out,
decoder_out,
)
# logits is of shape (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1)
# now logits is of shape (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1)
log_probs.add_(ys_log_probs)
log_probs = log_probs.reshape(-1)
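# After the broadcast add above, the flattened tensor holds a joint score
# for every (hypothesis, token) pair, so the topk below searches over all
# such pairs at once.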
topk_log_probs, topk_indexes = log_probs.topk(beam)
# topk_hyp_indexes are indexes into `A`
topk_hyp_indexes = topk_indexes // logits.size(-1)
topk_token_indexes = topk_indexes % logits.size(-1)
topk_hyp_indexes = topk_hyp_indexes.tolist()
topk_token_indexes = topk_token_indexes.tolist()
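# E.g. with 2 hypotheses and vocab_size 3, flattened index 4 decodes to
# hypothesis 4 // 3 == 1 and token 4 % 3 == 1.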
for i in range(len(topk_hyp_indexes)):
hyp = A[topk_hyp_indexes[i]]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[i]
if new_token != blank_id:
new_ys.append(new_token)
new_log_prob = topk_log_probs[i]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B.add(new_hyp)
best_hyp = B.get_most_probable(length_norm=True)
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
return ys
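A minimal usage sketch (hypothetical tensors; note that only a batch size of
1 is supported):

encoder_out, _ = model.encoder(x=features, x_lens=feature_lens)  # N == 1
hyp = modified_beam_search(model=model, encoder_out=encoder_out, beam=4)
# `hyp` is the token-id list of the best hypothesis.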
def beam_search(
model: Transducer,
encoder_out: torch.Tensor,
@@ -246,7 +434,9 @@ def beam_search(
device = model.device
decoder_input = torch.tensor(
[blank_id] * context_size, device=device
[blank_id] * context_size,
device=device,
dtype=torch.int64,
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
@@ -283,7 +473,9 @@ def beam_search(
if cached_key not in decoder_cache:
decoder_input = torch.tensor(
[y_star.ys[-context_size:]], device=device
[y_star.ys[-context_size:]],
device=device,
dtype=torch.int64,
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
@@ -309,7 +501,7 @@ def beam_search(
# First, process the blank symbol
skip_log_prob = log_prob[blank_id]
new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
new_y_star_log_prob = y_star.log_prob + skip_log_prob
# ys[:] returns a copy of ys
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))

View File

@@ -20,12 +20,18 @@ import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from beam_search import beam_search, greedy_search
from beam_search import (
beam_search,
fast_beam_search,
greedy_search,
modified_beam_search,
)
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
@@ -85,6 +91,8 @@ def get_parser():
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
@@ -92,7 +100,35 @@ def get_parser():
"--beam-size",
type=int,
default=4,
help="Used only when --decoding-method is beam_search",
help="""An interger indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
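# A toy illustration of the cutoff rule described in the help text above
# (made-up scores, not k2's actual pruning code):
#   scores = torch.tensor([-1.2, -3.5, -9.0])
#   cutoff = scores.max() - 4.0   # beam = 4.0
#   keep = scores > cutoff        # tensor([True, True, False])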
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
@@ -102,12 +138,14 @@ def get_parser():
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=3,
help="Maximum number of symbols per frame",
)
parser.add_argument(
"--export",
type=str2bool,
@@ -192,6 +230,7 @@ def decode_one_batch(
model: nn.Module,
lexicon: Lexicon,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@@ -208,12 +247,15 @@
It's the return value of :func:`get_params`.
model:
The neural model.
lexicon:
It contains the token symbol table and the word symbol table.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
lexicon:
It contains the token symbol table and the word symbol table.
decoding_graph:
The decoding graph. It can be either a `k2.trivial_graph` or an HLG. Used
only when --decoding-method is fast_beam_search.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@@ -232,32 +274,62 @@
x=feature, x_lens=feature_lens
)
hyps = []
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append([lexicon.token_table[i] for i in hyp])
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in hyp_tokens:
hyps.append([lexicon.token_table[i] for i in hyp])
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
elif params.decoding_method == "modified_beam_search":
hyp = modified_beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append([lexicon.token_table[i] for i in hyp])
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_{params.beam_size}": hyps}
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
@@ -265,6 +337,7 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@@ -275,6 +348,11 @@
It is returned by :func:`get_params`.
model:
The neural model.
lexicon:
It contains the token symbol table and the word symbol table.
decoding_graph:
The decoding graph. It can be either a `k2.trivial_graph` or an HLG. Used
only when --decoding-method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_size_7" if a beam size of 7 is used.
@@ -303,6 +381,7 @@
model=model,
lexicon=lexicon,
batch=batch,
decoding_graph=decoding_graph,
)
for name, hyps in hyps_dict.items():
@@ -383,11 +462,21 @@ def main():
params = get_params()
params.update(vars(args))
assert params.decoding_method in ("greedy_search", "beam_search")
assert params.decoding_method in (
"greedy_search",
"beam_search",
"modified_beam_search",
"fast_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.decoding_method == "beam_search":
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
@@ -435,6 +524,11 @@
model.eval()
model.device = device
if params.decoding_method == "fast_beam_search":
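# k2.trivial_graph takes the largest token id; ids run from 0 to
# vocab_size - 1, hence the argument below.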
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@@ -451,6 +545,7 @@
params=params,
model=model,
lexicon=lexicon,
decoding_graph=decoding_graph,
)
save_results(

View File

@ -58,6 +58,7 @@ class Decoder(nn.Module):
padding_idx=blank_id,
)
self.blank_id = blank_id
self.vocab_size = vocab_size
assert context_size >= 1, context_size
self.context_size = context_size