mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-03 06:04:18 +00:00
Update beam search to support max/log_add in selecting duplicate hyps.
This commit is contained in:
parent
395a3f952b
commit
52f1f6775d
@ -21,6 +21,153 @@ import k2
|
||||
import torch
|
||||
from model import Transducer
|
||||
|
||||
from icefall.decode import Nbest, one_best_decoding
|
||||
from icefall.utils import get_texts
|
||||
|
||||
|
||||
def fast_beam_search(
|
||||
model: Transducer,
|
||||
decoding_graph: k2.Fsa,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
beam: float,
|
||||
max_states: int,
|
||||
max_contexts: int,
|
||||
use_max: bool = False,
|
||||
) -> List[List[int]]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
|
||||
Args:
|
||||
model:
|
||||
An instance of `Transducer`.
|
||||
decoding_graph:
|
||||
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
|
||||
encoder_out:
|
||||
A tensor of shape (N, T, C) from the encoder.
|
||||
encoder_out_lens:
|
||||
A tensor of shape (N,) containing the number of frames in `encoder_out`
|
||||
before padding.
|
||||
beam:
|
||||
Beam value, similar to the beam used in Kaldi..
|
||||
max_states:
|
||||
Max states per stream per frame.
|
||||
max_contexts:
|
||||
Max contexts pre stream per frame.
|
||||
use_max:
|
||||
True to use max operation to select the hypothesis with the largest
|
||||
log_prob when there are duplicate hypotheses; False to use log-add.
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
"""
|
||||
assert encoder_out.ndim == 3
|
||||
|
||||
context_size = model.decoder.context_size
|
||||
vocab_size = model.decoder.vocab_size
|
||||
|
||||
B, T, C = encoder_out.shape
|
||||
|
||||
config = k2.RnntDecodingConfig(
|
||||
vocab_size=vocab_size,
|
||||
decoder_history_len=context_size,
|
||||
beam=beam,
|
||||
max_contexts=max_contexts,
|
||||
max_states=max_states,
|
||||
)
|
||||
individual_streams = []
|
||||
for i in range(B):
|
||||
individual_streams.append(k2.RnntDecodingStream(decoding_graph))
|
||||
decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
|
||||
|
||||
tmp_len = torch.tensor([1])
|
||||
|
||||
for t in range(T):
|
||||
# shape is a RaggedShape of shape (B, context)
|
||||
# contexts is a Tensor of shape (shape.NumElements(), context_size)
|
||||
shape, contexts = decoding_streams.get_contexts()
|
||||
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
|
||||
contexts = contexts.to(torch.int64)
|
||||
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
|
||||
decoder_out = model.decoder(contexts, need_pad=False)
|
||||
# current_encoder_out is of shape
|
||||
# (shape.NumElements(), 1, encoder_out_dim)
|
||||
# fmt: off
|
||||
current_encoder_out = torch.index_select(
|
||||
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1)
|
||||
)
|
||||
# fmt: on
|
||||
logits = model.joiner(
|
||||
current_encoder_out,
|
||||
decoder_out,
|
||||
tmp_len.expand(decoder_out.size(0)),
|
||||
tmp_len.expand(decoder_out.size(0)),
|
||||
)
|
||||
log_probs = logits.log_softmax(dim=-1)
|
||||
decoding_streams.advance(log_probs)
|
||||
decoding_streams.terminate_and_flush_to_streams()
|
||||
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
|
||||
|
||||
if use_max:
|
||||
best_path = one_best_decoding(lattice)
|
||||
hyps = get_texts(best_path)
|
||||
return hyps
|
||||
else:
|
||||
num_paths = 20
|
||||
use_double_scores = True
|
||||
nbest_scale = 0.8
|
||||
|
||||
nbest = Nbest.from_lattice(
|
||||
lattice=lattice,
|
||||
num_paths=num_paths,
|
||||
use_double_scores=use_double_scores,
|
||||
nbest_scale=nbest_scale,
|
||||
)
|
||||
# The following code is modified from nbest.intersect()
|
||||
word_fsa = k2.invert(nbest.fsa)
|
||||
if hasattr(lattice, "aux_labels"):
|
||||
# delete token IDs as it is not needed
|
||||
del word_fsa.aux_labels
|
||||
word_fsa.scores.zero_()
|
||||
word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops(
|
||||
word_fsa
|
||||
)
|
||||
path_to_utt_map = nbest.shape.row_ids(1)
|
||||
|
||||
if hasattr(lattice, "aux_labels"):
|
||||
# lattice has token IDs as labels and word IDs as aux_labels.
|
||||
# inv_lattice has word IDs as labels and token IDs as aux_labels
|
||||
inv_lattice = k2.invert(lattice)
|
||||
inv_lattice = k2.arc_sort(inv_lattice)
|
||||
else:
|
||||
inv_lattice = k2.arc_sort(lattice)
|
||||
|
||||
if inv_lattice.shape[0] == 1:
|
||||
path_lattice = k2.intersect_device(
|
||||
inv_lattice,
|
||||
word_fsa_with_epsilon_loops,
|
||||
b_to_a_map=torch.zeros_like(path_to_utt_map),
|
||||
sorted_match_a=True,
|
||||
)
|
||||
else:
|
||||
path_lattice = k2.intersect_device(
|
||||
inv_lattice,
|
||||
word_fsa_with_epsilon_loops,
|
||||
b_to_a_map=path_to_utt_map,
|
||||
sorted_match_a=True,
|
||||
)
|
||||
|
||||
# path_lattice has word IDs as labels and token IDs as aux_labels
|
||||
path_lattice = k2.top_sort(k2.connect(path_lattice))
|
||||
|
||||
tot_scores = path_lattice.get_tot_scores(
|
||||
use_double_scores=True, log_semiring=True
|
||||
)
|
||||
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
|
||||
best_hyp_indexes = ragged_tot_scores.argmax()
|
||||
|
||||
best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
|
||||
hyps = get_texts(best_path)
|
||||
return hyps
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
|
||||
@ -203,7 +350,7 @@ class HypothesisList(object):
|
||||
def data(self) -> Dict[str, Hypothesis]:
|
||||
return self._data
|
||||
|
||||
def add(self, hyp: Hypothesis) -> None:
|
||||
def add(self, hyp: Hypothesis, use_max: bool = False) -> None:
|
||||
"""Add a Hypothesis to `self`.
|
||||
|
||||
If `hyp` already exists in `self`, its probability is updated using
|
||||
@ -212,13 +359,20 @@ class HypothesisList(object):
|
||||
Args:
|
||||
hyp:
|
||||
The hypothesis to be added.
|
||||
use_max:
|
||||
True to select the hypothesis with the larger log_prob in case there
|
||||
already exists a hypothesis whose `ys` equals to `hyp.ys`.
|
||||
False to use log_add.
|
||||
"""
|
||||
key = hyp.key
|
||||
if key in self:
|
||||
old_hyp = self._data[key] # shallow copy
|
||||
torch.logaddexp(
|
||||
old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
|
||||
)
|
||||
if use_max:
|
||||
old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob)
|
||||
else:
|
||||
torch.logaddexp(
|
||||
old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
|
||||
)
|
||||
else:
|
||||
self._data[key] = hyp
|
||||
|
||||
@ -415,6 +569,7 @@ def modified_beam_search(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
beam: int = 4,
|
||||
use_max: bool = False,
|
||||
) -> List[List[int]]:
|
||||
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcodded.
|
||||
|
||||
@ -425,6 +580,10 @@ def modified_beam_search(
|
||||
Output from the encoder. Its shape is (N, T, C).
|
||||
beam:
|
||||
Number of active paths during the beam search.
|
||||
use_max:
|
||||
If True, it uses max operation to select the hypothesis with the
|
||||
larger log_prob in case two hypotheses have the same token sequences.
|
||||
If False, use log add.
|
||||
Returns:
|
||||
Return a list-of-list of token IDs. ans[i] is the decoding results
|
||||
for the i-th utterance.
|
||||
@ -443,7 +602,8 @@ def modified_beam_search(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
)
|
||||
),
|
||||
use_max=use_max,
|
||||
)
|
||||
|
||||
encoder_out_len = torch.tensor([1])
|
||||
@ -519,7 +679,7 @@ def modified_beam_search(
|
||||
|
||||
new_log_prob = topk_log_probs[k]
|
||||
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
||||
B[i].add(new_hyp)
|
||||
B[i].add(new_hyp, use_max=use_max)
|
||||
|
||||
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
||||
ans = [h.ys[context_size:] for h in best_hyps]
|
||||
@ -531,6 +691,7 @@ def _deprecated_modified_beam_search(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
beam: int = 4,
|
||||
use_max: bool = False,
|
||||
) -> List[int]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
|
||||
@ -545,6 +706,10 @@ def _deprecated_modified_beam_search(
|
||||
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
||||
beam:
|
||||
Beam size.
|
||||
use_max:
|
||||
If True, it uses max operation to select the hypothesis with the
|
||||
larger log_prob in case two hypotheses have the same token sequences.
|
||||
If False, use log add.
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
"""
|
||||
@ -565,7 +730,8 @@ def _deprecated_modified_beam_search(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
)
|
||||
),
|
||||
use_max=use_max,
|
||||
)
|
||||
|
||||
encoder_out_len = torch.tensor([1])
|
||||
@ -624,7 +790,7 @@ def _deprecated_modified_beam_search(
|
||||
new_ys.append(new_token)
|
||||
new_log_prob = topk_log_probs[i]
|
||||
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
||||
B.add(new_hyp)
|
||||
B.add(new_hyp, use_max=use_max)
|
||||
|
||||
best_hyp = B.get_most_probable(length_norm=True)
|
||||
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
|
||||
@ -636,6 +802,7 @@ def beam_search(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
beam: int = 4,
|
||||
use_max: bool = False,
|
||||
) -> List[int]:
|
||||
"""
|
||||
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
|
||||
@ -649,6 +816,10 @@ def beam_search(
|
||||
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
||||
beam:
|
||||
Beam size.
|
||||
use_max:
|
||||
If True, it uses max operation to select the hypothesis with the
|
||||
larger log_prob in case two hypotheses have the same token sequences.
|
||||
If False, use log add.
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
"""
|
||||
@ -675,7 +846,8 @@ def beam_search(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
)
|
||||
),
|
||||
use_max=use_max,
|
||||
)
|
||||
|
||||
max_sym_per_utt = 20000
|
||||
@ -721,7 +893,10 @@ def beam_search(
|
||||
new_y_star_log_prob = y_star.log_prob + skip_log_prob
|
||||
|
||||
# ys[:] returns a copy of ys
|
||||
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
|
||||
B.add(
|
||||
Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob),
|
||||
use_max=use_max,
|
||||
)
|
||||
|
||||
# Second, process other non-blank labels
|
||||
values, indices = log_prob.topk(beam + 1)
|
||||
@ -733,7 +908,10 @@ def beam_search(
|
||||
new_ys = y_star.ys + [i]
|
||||
|
||||
new_log_prob = y_star.log_prob + values[idx]
|
||||
A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
|
||||
A.add(
|
||||
Hypothesis(ys=new_ys, log_prob=new_log_prob),
|
||||
use_max=use_max,
|
||||
)
|
||||
|
||||
# Check whether B contains more than "beam" elements more probable
|
||||
# than the most probable in A
|
||||
|
@ -22,7 +22,8 @@ Usage:
|
||||
--epoch 14 \
|
||||
--avg 7 \
|
||||
--exp-dir ./transducer_stateless/exp \
|
||||
--max-duration 100 \
|
||||
--max-duration 1000 \
|
||||
--max-sym-per-frame 1 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search
|
||||
@ -30,7 +31,7 @@ Usage:
|
||||
--epoch 14 \
|
||||
--avg 7 \
|
||||
--exp-dir ./transducer_stateless/exp \
|
||||
--max-duration 100 \
|
||||
--max-duration 1000 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
@ -39,9 +40,20 @@ Usage:
|
||||
--epoch 14 \
|
||||
--avg 7 \
|
||||
--exp-dir ./transducer_stateless/exp \
|
||||
--max-duration 100 \
|
||||
--max-duration 1000 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(4) fast beam search
|
||||
./transducer_stateless/decode.py \
|
||||
--epoch 14 \
|
||||
--avg 7 \
|
||||
--exp-dir ./transducer_stateless/exp \
|
||||
--max-duration 1000 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8
|
||||
"""
|
||||
|
||||
|
||||
@ -49,14 +61,16 @@ import argparse
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
@ -68,6 +82,7 @@ from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
@ -115,6 +130,7 @@ def get_parser():
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
@ -126,6 +142,32 @@ def get_parser():
|
||||
beam_search or modified_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
@ -141,6 +183,17 @@ def get_parser():
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-max",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""If True, use max-op to select the hypothesis that have the
|
||||
max log_prob in case of duplicate hypotheses.
|
||||
If False, use log_add.
|
||||
Used only for beam_search, modified_beam_search, and fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -149,6 +202,7 @@ def decode_one_batch(
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
@ -190,7 +244,18 @@ def decode_one_batch(
|
||||
)
|
||||
hyp_list: List[List[int]] = []
|
||||
|
||||
if (
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_list = fast_beam_search(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
use_max=params.use_max,
|
||||
)
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
@ -203,6 +268,7 @@ def decode_one_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
beam=params.beam_size,
|
||||
use_max=params.use_max,
|
||||
)
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
@ -221,6 +287,7 @@ def decode_one_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
use_max=params.use_max,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
@ -232,6 +299,14 @@ def decode_one_batch(
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
return {
|
||||
(
|
||||
f"beam_{params.beam}_"
|
||||
f"max_contexts_{params.max_contexts}_"
|
||||
f"max_states_{params.max_states}"
|
||||
): hyps
|
||||
}
|
||||
else:
|
||||
return {f"beam_{params.beam_size}": hyps}
|
||||
|
||||
@ -241,6 +316,7 @@ def decode_dataset(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
@ -281,6 +357,7 @@ def decode_dataset(
|
||||
model=model,
|
||||
sp=sp,
|
||||
batch=batch,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
@ -358,15 +435,22 @@ def main():
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"greedy_search",
|
||||
"modified_beam_search",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
if "beam_search" in params.decoding_method:
|
||||
if "fast_beam_search" == params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
params.suffix += f"-use-max-{params.use_max}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam_size}"
|
||||
params.suffix += f"-use-max-{params.use_max}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
@ -408,6 +492,11 @@ def main():
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
else:
|
||||
decoding_graph = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
@ -428,6 +517,7 @@ def main():
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
|
@ -58,6 +58,7 @@ class Decoder(nn.Module):
|
||||
padding_idx=blank_id,
|
||||
)
|
||||
self.blank_id = blank_id
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
assert context_size >= 1, context_size
|
||||
self.context_size = context_size
|
||||
|
@ -218,7 +218,7 @@ class Nbest(object):
|
||||
# word_seq is a k2.RaggedTensor sharing the same shape as `path`
|
||||
# but it contains word IDs. Note that it also contains 0s and -1s.
|
||||
# The last entry in each sublist is -1.
|
||||
# It axes is [utt][path][word_id]
|
||||
# Its axes are [utt][path][word_id]
|
||||
if isinstance(lattice.aux_labels, torch.Tensor):
|
||||
word_seq = k2.ragged.index(lattice.aux_labels, path)
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user