Update beam search to support max/log_add in selecting duplicate hyps.

Fangjun Kuang 2022-03-28 12:33:58 +08:00
parent 395a3f952b
commit 52f1f6775d
4 changed files with 288 additions and 19 deletions
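The heart of the change: when two hypotheses in a beam share the same token sequence, their scores can now be merged either by keeping the larger log-probability (max) or, as before, by log-adding them. A minimal sketch of the two rules in plain PyTorch (not part of the commit; the scores are made up):

import torch

# Two duplicate hypotheses (identical token sequence) with different scores.
log_p1 = torch.tensor([-1.2])
log_p2 = torch.tensor([-2.3])

# use_max=True: keep only the better-scoring path.
merged_max = torch.maximum(log_p1, log_p2)  # -1.2

# use_max=False (default): log-add, i.e. sum the two path probabilities
# in the log domain.
merged_log_add = torch.logaddexp(log_p1, log_p2)  # about -0.91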

transducer_stateless/beam_search.py

@@ -21,6 +21,153 @@ import k2
import torch
from model import Transducer
from icefall.decode import Nbest, one_best_decoding
from icefall.utils import get_texts
def fast_beam_search(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
use_max: bool = False,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding; may be a TrivialGraph or an HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
use_max:
True to use max operation to select the hypothesis with the largest
log_prob when there are duplicate hypotheses; False to use log-add.
Returns:
Return the decoded result.
"""
assert encoder_out.ndim == 3
context_size = model.decoder.context_size
vocab_size = model.decoder.vocab_size
B, T, C = encoder_out.shape
config = k2.RnntDecodingConfig(
vocab_size=vocab_size,
decoder_history_len=context_size,
beam=beam,
max_contexts=max_contexts,
max_states=max_states,
)
individual_streams = []
for i in range(B):
individual_streams.append(k2.RnntDecodingStream(decoding_graph))
decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
tmp_len = torch.tensor([1])
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
# current_encoder_out is of shape
# (shape.NumElements(), 1, encoder_out_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1)
)
# fmt: on
logits = model.joiner(
current_encoder_out,
decoder_out,
tmp_len.expand(decoder_out.size(0)),
tmp_len.expand(decoder_out.size(0)),
)
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
if use_max:
best_path = one_best_decoding(lattice)
hyps = get_texts(best_path)
return hyps
else:
num_paths = 20
use_double_scores = True
nbest_scale = 0.8
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# The following code is modified from nbest.intersect()
word_fsa = k2.invert(nbest.fsa)
if hasattr(lattice, "aux_labels"):
# delete token IDs as they are not needed
del word_fsa.aux_labels
word_fsa.scores.zero_()
word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops(
word_fsa
)
path_to_utt_map = nbest.shape.row_ids(1)
if hasattr(lattice, "aux_labels"):
# lattice has token IDs as labels and word IDs as aux_labels.
# inv_lattice has word IDs as labels and token IDs as aux_labels
inv_lattice = k2.invert(lattice)
inv_lattice = k2.arc_sort(inv_lattice)
else:
inv_lattice = k2.arc_sort(lattice)
if inv_lattice.shape[0] == 1:
path_lattice = k2.intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=torch.zeros_like(path_to_utt_map),
sorted_match_a=True,
)
else:
path_lattice = k2.intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=path_to_utt_map,
sorted_match_a=True,
)
# path_lattice has word IDs as labels and token IDs as aux_labels
path_lattice = k2.top_sort(k2.connect(path_lattice))
tot_scores = path_lattice.get_tot_scores(
use_double_scores=True, log_semiring=True
)
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
best_hyp_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
hyps = get_texts(best_path)
return hyps
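A hedged usage sketch of the new function (not part of the commit): it assumes `model` is a trained Transducer and that `encoder_out` / `encoder_out_lens` come from a forward pass of its encoder, and it mirrors how decode.py below calls it:

import k2
from beam_search import fast_beam_search

decoding_graph = k2.trivial_graph(
    model.decoder.vocab_size - 1, device=encoder_out.device
)
hyp_tokens = fast_beam_search(
    model=model,
    decoding_graph=decoding_graph,
    encoder_out=encoder_out,
    encoder_out_lens=encoder_out_lens,
    beam=4.0,
    max_states=8,
    max_contexts=4,
    use_max=False,  # False => merge duplicate hypotheses with log-add
)
# hyp_tokens[i] is the decoded token-ID sequence for the i-th utterance.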
def greedy_search(
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
@@ -203,7 +350,7 @@ class HypothesisList(object):
def data(self) -> Dict[str, Hypothesis]:
return self._data
def add(self, hyp: Hypothesis, use_max: bool = False) -> None:
"""Add a Hypothesis to `self`.
If `hyp` already exists in `self`, its probability is updated using
@@ -212,13 +359,20 @@ class HypothesisList(object):
Args:
hyp:
The hypothesis to be added.
use_max:
True to select the hypothesis with the larger log_prob in case there
already exists a hypothesis whose `ys` equals `hyp.ys`.
False to use log_add.
"""
key = hyp.key
if key in self:
old_hyp = self._data[key] # shallow copy
if use_max:
old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob)
else:
torch.logaddexp(
old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
)
else:
self._data[key] = hyp
@@ -415,6 +569,7 @@ def modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
use_max: bool = False,
) -> List[List[int]]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
@@ -425,6 +580,10 @@ def modified_beam_search(
Output from the encoder. Its shape is (N, T, C).
beam:
Number of active paths during the beam search.
use_max:
If True, it uses max operation to select the hypothesis with the
larger log_prob in case two hypotheses have the same token sequences.
If False, use log add.
Returns:
Return a list-of-list of token IDs. ans[i] is the decoding results
for the i-th utterance.
@@ -443,7 +602,8 @@ def modified_beam_search(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
),
use_max=use_max,
)
encoder_out_len = torch.tensor([1])
@@ -519,7 +679,7 @@ def modified_beam_search(
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B[i].add(new_hyp, use_max=use_max)
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
ans = [h.ys[context_size:] for h in best_hyps]
@@ -531,6 +691,7 @@ def _deprecated_modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
use_max: bool = False,
) -> List[int]:
"""It limits the maximum number of symbols per frame to 1.
@@ -545,6 +706,10 @@ def _deprecated_modified_beam_search(
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
beam:
Beam size.
use_max:
If True, it uses max operation to select the hypothesis with the
larger log_prob in case two hypotheses have the same token sequences.
If False, use log add.
Returns:
Return the decoded result.
"""
@@ -565,7 +730,8 @@ def _deprecated_modified_beam_search(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
),
use_max=use_max,
)
encoder_out_len = torch.tensor([1])
@@ -624,7 +790,7 @@ def _deprecated_modified_beam_search(
new_ys.append(new_token)
new_log_prob = topk_log_probs[i]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B.add(new_hyp, use_max=use_max)
best_hyp = B.get_most_probable(length_norm=True)
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
@@ -636,6 +802,7 @@ def beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
use_max: bool = False,
) -> List[int]:
"""
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
@@ -649,6 +816,10 @@ def beam_search(
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
beam:
Beam size.
use_max:
If True, it uses max operation to select the hypothesis with the
larger log_prob in case two hypotheses have the same token sequences.
If False, use log add.
Returns:
Return the decoded result.
"""
@@ -675,7 +846,8 @@ def beam_search(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
),
use_max=use_max,
)
max_sym_per_utt = 20000
@@ -721,7 +893,10 @@ def beam_search(
new_y_star_log_prob = y_star.log_prob + skip_log_prob
# ys[:] returns a copy of ys
B.add(
Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob),
use_max=use_max,
)
# Second, process other non-blank labels
values, indices = log_prob.topk(beam + 1)
@@ -733,7 +908,10 @@ def beam_search(
new_ys = y_star.ys + [i]
new_log_prob = y_star.log_prob + values[idx]
A.add(
Hypothesis(ys=new_ys, log_prob=new_log_prob),
use_max=use_max,
)
# Check whether B contains more than "beam" elements more probable
# than the most probable in A

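To make the new use_max flag on HypothesisList.add concrete, here is a small hedged sketch using the Hypothesis and HypothesisList classes from this file (the token IDs and scores are made up):

import torch
from beam_search import Hypothesis, HypothesisList

hyps = HypothesisList()
hyps.add(Hypothesis(ys=[0, 0, 7], log_prob=torch.tensor([-1.2])))

# A second hypothesis with the same token sequence `ys`:
dup = Hypothesis(ys=[0, 0, 7], log_prob=torch.tensor([-2.3]))

# Default behaviour (use_max=False): the two scores are log-added,
# giving about -0.91. With use_max=True the stored score would instead
# stay at max(-1.2, -2.3) = -1.2.
hyps.add(dup, use_max=False)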
transducer_stateless/decode.py

@@ -22,7 +22,8 @@ Usage:
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless/exp \
--max-duration 1000 \
--max-sym-per-frame 1 \
--decoding-method greedy_search
(2) beam search
@@ -30,7 +31,7 @@ Usage:
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless/exp \
--max-duration 1000 \
--decoding-method beam_search \
--beam-size 4
@@ -39,9 +40,20 @@ Usage:
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless/exp \
--max-duration 1000 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
./transducer_stateless/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless/exp \
--max-duration 1000 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
"""
@@ -49,14 +61,16 @@ import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search,
greedy_search,
greedy_search_batch,
modified_beam_search,
@@ -68,6 +82,7 @@ from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
@@ -115,6 +130,7 @@ def get_parser():
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
@@ -126,6 +142,32 @@ def get_parser():
beam_search or modified_beam_search""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
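For intuition on the `cutoff = max-score - beam` rule in the help text above, a tiny hedged sketch with made-up scores:

# Frame-level pruning as described above: with --beam 4.0, any state whose
# score falls more than 4.0 below the best score on that frame is dropped.
best_score = 10.5
beam = 4.0
cutoff = best_score - beam  # 6.5
scores = [10.5, 9.0, 7.1, 5.9]
kept = [s for s in scores if s >= cutoff]  # [10.5, 9.0, 7.1]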
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
@@ -141,6 +183,17 @@ def get_parser():
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--use-max",
type=str2bool,
default=False,
help="""If True, use max-op to select the hypothesis that has the
max log_prob in case of duplicate hypotheses.
If False, use log_add.
Used only for beam_search, modified_beam_search, and fast_beam_search
""",
)
return parser
@@ -149,6 +202,7 @@ def decode_one_batch(
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@@ -190,7 +244,18 @@ def decode_one_batch(
)
hyp_list: List[List[int]] = []
if params.decoding_method == "fast_beam_search":
hyp_list = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
use_max=params.use_max,
)
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
@@ -203,6 +268,7 @@
model=model,
encoder_out=encoder_out,
beam=params.beam_size,
use_max=params.use_max,
)
else:
batch_size = encoder_out.size(0)
@@ -221,6 +287,7 @@
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
use_max=params.use_max,
)
else:
raise ValueError(
@@ -232,6 +299,14 @@
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_{params.beam_size}": hyps}
@@ -241,6 +316,7 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@@ -281,6 +357,7 @@
model=model,
sp=sp,
batch=batch,
decoding_graph=decoding_graph,
)
for name, hyps in hyps_dict.items():
@@ -358,15 +435,22 @@ def main():
params.update(vars(args))
assert params.decoding_method in (
"beam_search",
"fast_beam_search",
"greedy_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" == params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
params.suffix += f"-use-max-{params.use_max}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
params.suffix += f"-use-max-{params.use_max}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -408,6 +492,11 @@
model.eval()
model.device = device
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@@ -428,6 +517,7 @@
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(

transducer_stateless/decoder.py

@@ -58,6 +58,7 @@ class Decoder(nn.Module):
padding_idx=blank_id,
)
self.blank_id = blank_id
self.vocab_size = vocab_size
assert context_size >= 1, context_size
self.context_size = context_size

icefall/decode.py

@@ -218,7 +218,7 @@ class Nbest(object):
# word_seq is a k2.RaggedTensor sharing the same shape as `path`
# but it contains word IDs. Note that it also contains 0s and -1s.
# The last entry in each sublist is -1.
# Its axes are [utt][path][word_id]
if isinstance(lattice.aux_labels, torch.Tensor):
word_seq = k2.ragged.index(lattice.aux_labels, path)
else:
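For readers unfamiliar with the `[utt][path][word_id]` layout mentioned in the comment above, a small hedged sketch of a 3-axis ragged tensor with that structure (toy word IDs; assumes k2's string constructor):

import k2

# Two utterances: the first has two candidate paths, the second has one.
# Axis 0 -> utterance, axis 1 -> path, axis 2 -> word ID (ending with -1).
word_seq = k2.RaggedTensor("[ [ [5 9 -1] [5 7 -1] ] [ [3 -1] ] ]")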