Update beam search to support max/log_add in selecting duplicate hyps.

This commit is contained in:
Fangjun Kuang 2022-03-28 12:33:58 +08:00
parent 395a3f952b
commit 52f1f6775d
4 changed files with 288 additions and 19 deletions


@ -21,6 +21,153 @@ import k2
import torch
from model import Transducer
from icefall.decode import Nbest, one_best_decoding
from icefall.utils import get_texts
def fast_beam_search(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
use_max: bool = False,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding; it may be a trivial graph or an HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
use_max:
True to use the max operation to select the hypothesis with the largest
log_prob when there are duplicate hypotheses; False to use log-add.
Returns:
Return the decoded result.
"""
assert encoder_out.ndim == 3
context_size = model.decoder.context_size
vocab_size = model.decoder.vocab_size
B, T, C = encoder_out.shape
config = k2.RnntDecodingConfig(
vocab_size=vocab_size,
decoder_history_len=context_size,
beam=beam,
max_contexts=max_contexts,
max_states=max_states,
)
individual_streams = []
for i in range(B):
individual_streams.append(k2.RnntDecodingStream(decoding_graph))
decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
tmp_len = torch.tensor([1])
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
# current_encoder_out is of shape
# (shape.NumElements(), 1, encoder_out_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1)
)
# fmt: on
logits = model.joiner(
current_encoder_out,
decoder_out,
tmp_len.expand(decoder_out.size(0)),
tmp_len.expand(decoder_out.size(0)),
)
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
if use_max:
best_path = one_best_decoding(lattice)
hyps = get_texts(best_path)
return hyps
else:
num_paths = 20
use_double_scores = True
nbest_scale = 0.8
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# The following code is modified from nbest.intersect()
word_fsa = k2.invert(nbest.fsa)
if hasattr(lattice, "aux_labels"):
# delete token IDs as they are not needed
del word_fsa.aux_labels
word_fsa.scores.zero_()
word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops(
word_fsa
)
path_to_utt_map = nbest.shape.row_ids(1)
if hasattr(lattice, "aux_labels"):
# lattice has token IDs as labels and word IDs as aux_labels.
# inv_lattice has word IDs as labels and token IDs as aux_labels
inv_lattice = k2.invert(lattice)
inv_lattice = k2.arc_sort(inv_lattice)
else:
inv_lattice = k2.arc_sort(lattice)
if inv_lattice.shape[0] == 1:
path_lattice = k2.intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=torch.zeros_like(path_to_utt_map),
sorted_match_a=True,
)
else:
path_lattice = k2.intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=path_to_utt_map,
sorted_match_a=True,
)
# path_lattice has word IDs as labels and token IDs as aux_labels
path_lattice = k2.top_sort(k2.connect(path_lattice))
tot_scores = path_lattice.get_tot_scores(
use_double_scores=True, log_semiring=True
)
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
best_hyp_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
hyps = get_texts(best_path)
return hyps
def greedy_search(
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
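
For orientation, here is a minimal sketch of how the fast_beam_search function added above might be driven end to end. It assumes a trained Transducer `model`, a SentencePiece processor `sp`, and already-computed `features`/`feature_lens`; these names are illustrative and not part of the diff.

import k2

# Acoustic embeddings of shape (N, T, C) plus the number of valid frames per utterance.
encoder_out, encoder_out_lens = model.encoder(features, feature_lens)

# A trivial graph (a single looping state with a self-loop per token) gives plain
# RNN-T decoding; passing an HLG instead would constrain the search to that graph.
decoding_graph = k2.trivial_graph(
    model.decoder.vocab_size - 1, device=encoder_out.device
)

hyp_tokens = fast_beam_search(
    model=model,
    decoding_graph=decoding_graph,
    encoder_out=encoder_out,
    encoder_out_lens=encoder_out_lens,
    beam=4.0,
    max_states=8,
    max_contexts=4,
    use_max=False,  # False: log-add style scoring, as described in the docstring above
)
# With a trivial graph the returned IDs are BPE token IDs, so they can be decoded
# directly with the SentencePiece model.
hyps = [sp.decode(tokens) for tokens in hyp_tokens]
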
@ -203,7 +350,7 @@ class HypothesisList(object):
def data(self) -> Dict[str, Hypothesis]:
return self._data
def add(self, hyp: Hypothesis) -> None:
def add(self, hyp: Hypothesis, use_max: bool = False) -> None:
"""Add a Hypothesis to `self`.
If `hyp` already exists in `self`, its probability is updated using
@ -212,10 +359,17 @@ class HypothesisList(object):
Args:
hyp:
The hypothesis to be added.
use_max:
True to select the hypothesis with the larger log_prob when there
already exists a hypothesis whose `ys` equals `hyp.ys`.
False to use log-add.
"""
key = hyp.key
if key in self:
old_hyp = self._data[key] # shallow copy
if use_max:
old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob)
else:
torch.logaddexp(
old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
)
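
The heart of the change is the merge rule in HypothesisList.add above: when two hypotheses share the same token sequence, their scores are combined either by keeping the better one or by log-add. A tiny sketch with made-up numbers (not from the recipe) showing the difference:

import torch

log_p1 = torch.tensor([-1.2])  # score of the hypothesis already in the list
log_p2 = torch.tensor([-2.3])  # score of the duplicate being added

# use_max=True keeps the better of the two (Viterbi-style merging).
merged_max = torch.maximum(log_p1, log_p2)       # tensor([-1.2000])

# use_max=False sums the probabilities in the log domain.
merged_logadd = torch.logaddexp(log_p1, log_p2)  # roughly tensor([-0.913])
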
@ -415,6 +569,7 @@ def modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
use_max: bool = False,
) -> List[List[int]]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcodded.
@ -425,6 +580,10 @@ def modified_beam_search(
Output from the encoder. Its shape is (N, T, C).
beam:
Number of active paths during the beam search.
use_max:
If True, use the max operation to select the hypothesis with the
larger log_prob when two hypotheses have the same token sequence.
If False, use log-add.
Returns:
Return a list-of-list of token IDs. ans[i] is the decoding results
for the i-th utterance.
@ -443,7 +602,8 @@ def modified_beam_search(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
),
use_max=use_max,
)
encoder_out_len = torch.tensor([1])
@ -519,7 +679,7 @@ def modified_beam_search(
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B[i].add(new_hyp)
B[i].add(new_hyp, use_max=use_max)
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
ans = [h.ys[context_size:] for h in best_hyps]
@ -531,6 +691,7 @@ def _deprecated_modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
use_max: bool = False,
) -> List[int]:
"""It limits the maximum number of symbols per frame to 1.
@ -545,6 +706,10 @@ def _deprecated_modified_beam_search(
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
beam:
Beam size.
use_max:
If True, use the max operation to select the hypothesis with the
larger log_prob when two hypotheses have the same token sequence.
If False, use log-add.
Returns:
Return the decoded result.
"""
@ -565,7 +730,8 @@ def _deprecated_modified_beam_search(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
),
use_max=use_max,
)
encoder_out_len = torch.tensor([1])
@ -624,7 +790,7 @@ def _deprecated_modified_beam_search(
new_ys.append(new_token)
new_log_prob = topk_log_probs[i]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B.add(new_hyp)
B.add(new_hyp, use_max=use_max)
best_hyp = B.get_most_probable(length_norm=True)
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
@ -636,6 +802,7 @@ def beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
use_max: bool = False,
) -> List[int]:
"""
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
@ -649,6 +816,10 @@ def beam_search(
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
beam:
Beam size.
use_max:
If True, use the max operation to select the hypothesis with the
larger log_prob when two hypotheses have the same token sequence.
If False, use log-add.
Returns:
Return the decoded result.
"""
@ -675,7 +846,8 @@ def beam_search(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
),
use_max=use_max,
)
max_sym_per_utt = 20000
@ -721,7 +893,10 @@ def beam_search(
new_y_star_log_prob = y_star.log_prob + skip_log_prob
# ys[:] returns a copy of ys
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
B.add(
Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob),
use_max=use_max,
)
# Second, process other non-blank labels
values, indices = log_prob.topk(beam + 1)
@ -733,7 +908,10 @@ def beam_search(
new_ys = y_star.ys + [i]
new_log_prob = y_star.log_prob + values[idx]
A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
A.add(
Hypothesis(ys=new_ys, log_prob=new_log_prob),
use_max=use_max,
)
# Check whether B contains more than "beam" elements more probable
# than the most probable in A
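
The truncated comment above refers to the stopping test of Algorithm 1 in the Graves paper: expansion of the current frame ends once enough finished hypotheses in B outscore the best candidate still waiting in A. A self-contained sketch of that test with made-up scores (the exact data structures and comparison used in the recipe may differ slightly):

from dataclasses import dataclass

@dataclass
class Hyp:  # stand-in for the Hypothesis class defined earlier in this file
    ys: list
    log_prob: float

beam = 4
A = [Hyp([1, 2], -3.0), Hyp([1, 3], -4.5)]                   # candidates still to expand
B = [Hyp([1], -1.0), Hyp([2], -2.0),
     Hyp([1, 2, 5], -2.5), Hyp([3], -2.8), Hyp([4], -5.0)]   # finished for this frame

best_in_A = max(h.log_prob for h in A)                       # -3.0
num_better = sum(1 for h in B if h.log_prob > best_in_A)
done_with_frame = num_better >= beam                         # True: 4 hyps in B beat -3.0
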


@ -22,7 +22,8 @@ Usage:
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless/exp \
--max-duration 100 \
--max-duration 1000 \
--max-sym-per-frame 1 \
--decoding-method greedy_search
(2) beam search
@ -30,7 +31,7 @@ Usage:
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless/exp \
--max-duration 100 \
--max-duration 1000 \
--decoding-method beam_search \
--beam-size 4
@ -39,9 +40,20 @@ Usage:
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless/exp \
--max-duration 100 \
--max-duration 1000 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
./transducer_stateless/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless/exp \
--max-duration 1000 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
"""
@ -49,14 +61,16 @@ import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search,
greedy_search,
greedy_search_batch,
modified_beam_search,
@ -68,6 +82,7 @@ from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
@ -115,6 +130,7 @@ def get_parser():
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
@ -126,6 +142,32 @@ def get_parser():
beam_search or modified_beam_search""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
@ -141,6 +183,17 @@ def get_parser():
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--use-max",
type=str2bool,
default=False,
help="""If True, use max-op to select the hypothesis that have the
max log_prob in case of duplicate hypotheses.
If False, use log_add.
Used only for beam_search, modified_beam_search, and fast_beam_search
""",
)
return parser
@ -149,6 +202,7 @@ def decode_one_batch(
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -190,7 +244,18 @@ def decode_one_batch(
)
hyp_list: List[List[int]] = []
if (
if params.decoding_method == "fast_beam_search":
hyp_list = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
use_max=params.use_max,
)
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
@ -203,6 +268,7 @@ def decode_one_batch(
model=model,
encoder_out=encoder_out,
beam=params.beam_size,
use_max=params.use_max,
)
else:
batch_size = encoder_out.size(0)
@ -221,6 +287,7 @@ def decode_one_batch(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
use_max=params.use_max,
)
else:
raise ValueError(
@ -232,6 +299,14 @@ def decode_one_batch(
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_{params.beam_size}": hyps}
@ -241,6 +316,7 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@ -281,6 +357,7 @@ def decode_dataset(
model=model,
sp=sp,
batch=batch,
decoding_graph=decoding_graph,
)
for name, hyps in hyps_dict.items():
@ -358,15 +435,22 @@ def main():
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"greedy_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "beam_search" in params.decoding_method:
if "fast_beam_search" == params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
params.suffix += f"-use-max-{params.use_max}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
params.suffix += f"-use-max-{params.use_max}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -408,6 +492,11 @@ def main():
model.eval()
model.device = device
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@ -428,6 +517,7 @@ def main():
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(


@ -58,6 +58,7 @@ class Decoder(nn.Module):
padding_idx=blank_id,
)
self.blank_id = blank_id
self.vocab_size = vocab_size
assert context_size >= 1, context_size
self.context_size = context_size


@ -218,7 +218,7 @@ class Nbest(object):
# word_seq is a k2.RaggedTensor sharing the same shape as `path`
# but it contains word IDs. Note that it also contains 0s and -1s.
# The last entry in each sublist is -1.
# It axes is [utt][path][word_id]
# Its axes are [utt][path][word_id]
if isinstance(lattice.aux_labels, torch.Tensor):
word_seq = k2.ragged.index(lattice.aux_labels, path)
else: