mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 14:44:18 +00:00
Compute the Nbest oracle WER for RNN-T decoding.
This commit is contained in:
parent
e9f0975868
commit
b1c3705fbe
@ -22,11 +22,11 @@ import k2
|
||||
import torch
|
||||
from model import Transducer
|
||||
|
||||
from icefall.decode import one_best_decoding
|
||||
from icefall.decode import Nbest, one_best_decoding
|
||||
from icefall.utils import get_texts
|
||||
|
||||
|
||||
def fast_beam_search(
|
||||
def fast_beam_search_one_best(
|
||||
model: Transducer,
|
||||
decoding_graph: k2.Fsa,
|
||||
encoder_out: torch.Tensor,
|
||||
@ -37,6 +37,9 @@ def fast_beam_search(
|
||||
) -> List[List[int]]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
|
||||
A lattice is first obtained using fast beam search, and then
|
||||
the shortest path within the lattice is used as the final output.
|
||||
|
||||
Args:
|
||||
model:
|
||||
An instance of `Transducer`.
|
||||
@ -56,6 +59,153 @@ def fast_beam_search(
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
"""
|
||||
lattice = fast_beam_search(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=beam,
|
||||
max_states=max_states,
|
||||
max_contexts=max_contexts,
|
||||
)
|
||||
|
||||
best_path = one_best_decoding(lattice)
|
||||
hyps = get_texts(best_path)
|
||||
return hyps
|
||||
|
||||
|
||||
def fast_beam_search_nbest_oracle(
    model: Transducer,
    decoding_graph: k2.Fsa,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    beam: float,
    max_states: int,
    max_contexts: int,
    num_paths: int,
    ref_texts: List[List[int]],
    use_double_scores: bool = True,
    nbest_scale: float = 0.5,
) -> List[List[int]]:
    """Oracle decoding over an n-best list drawn from a fast-beam-search
    lattice. It limits the maximum number of symbols per frame to 1.

    A lattice is produced by `fast_beam_search`, `num_paths` linear paths
    are extracted from it, and for every utterance the path whose
    Levenshtein alignment against the reference transcript scores best is
    returned. This is the best result any nbest based rescoring method
    could achieve on the same lattice.

    Args:
      model:
        An instance of `Transducer`.
      decoding_graph:
        Decoding graph used for decoding, may be a TrivialGraph or a HLG.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder.
      encoder_out_lens:
        A tensor of shape (N,) containing the number of frames in
        `encoder_out` before padding.
      beam:
        Beam value, similar to the beam used in Kaldi.
      max_states:
        Max states per stream per frame.
      max_contexts:
        Max contexts per stream per frame.
      num_paths:
        Number of paths to extract from the decoded lattice.
      ref_texts:
        A list-of-list of integers containing the reference transcripts.
        If the decoding_graph is a trivial_graph, the integer ID is the
        BPE token ID.
      use_double_scores:
        True to use double precision when extracting the n-best paths.
        False to use single precision.
      nbest_scale:
        It's the scale applied to the lattice.scores. A smaller value
        yields more unique paths.

    Returns:
      Return the decoded result: the output of `get_texts` applied to the
      selected path of each utterance.
    """
    lattice = fast_beam_search(
        model=model,
        decoding_graph=decoding_graph,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        beam=beam,
        max_states=max_states,
        max_contexts=max_contexts,
    )

    candidates = Nbest.from_lattice(
        lattice=lattice,
        num_paths=num_paths,
        use_double_scores=use_double_scores,
        nbest_scale=nbest_scale,
    )

    # Assumption carried over from the original implementation: the labels
    # of `candidates.fsa` are token IDs and its aux_labels are word IDs.
    # The FSA is inverted before the ID sequences are read out of it.
    inverted_paths = k2.invert(candidates.fsa)
    candidate_ids = get_texts(inverted_paths, return_ragged=True)

    hyp_graphs = k2.levenshtein_graph(candidate_ids)
    ref_graphs = k2.levenshtein_graph(ref_texts, device=hyp_graphs.device)

    # The row ids of the n-best shape are handed over as the
    # hypothesis-to-reference map, pairing each extracted path with the
    # reference of the utterance it came from.
    path_to_ref = candidates.shape.row_ids(1)
    alignment = k2.levenshtein_alignment(
        refs=ref_graphs,
        hyps=hyp_graphs,
        hyp_to_ref_map=path_to_ref,
        sorted_match_ref=True,
    )

    # Single precision and the tropical semiring are used for the
    # alignment scores regardless of `use_double_scores`.
    path_scores = alignment.get_tot_scores(
        use_double_scores=False,
        log_semiring=False,
    )

    # Regroup the flat per-path scores by utterance and keep, for each
    # utterance, the index of its highest-scoring path.
    # NOTE(review): presumably the scores are negated edit costs, so the
    # argmax corresponds to the minimum edit distance -- confirm in k2.
    scores_by_utt = k2.RaggedTensor(candidates.shape, path_scores)
    oracle_indexes = scores_by_utt.argmax()

    oracle_paths = k2.index_fsa(candidates.fsa, oracle_indexes)
    return get_texts(oracle_paths)
|
||||
|
||||
|
||||
def fast_beam_search(
|
||||
model: Transducer,
|
||||
decoding_graph: k2.Fsa,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
beam: float,
|
||||
max_states: int,
|
||||
max_contexts: int,
|
||||
) -> k2.Fsa:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
|
||||
Args:
|
||||
model:
|
||||
An instance of `Transducer`.
|
||||
decoding_graph:
|
||||
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
|
||||
encoder_out:
|
||||
A tensor of shape (N, T, C) from the encoder.
|
||||
encoder_out_lens:
|
||||
A tensor of shape (N,) containing the number of frames in `encoder_out`
|
||||
before padding.
|
||||
beam:
|
||||
Beam value, similar to the beam used in Kaldi.
|
||||
max_states:
|
||||
Max states per stream per frame.
|
||||
max_contexts:
|
||||
Max contexts per stream per frame.
|
||||
Returns:
|
||||
Return an FsaVec with axes [utt][state][arc] containing the decoded
|
||||
lattice. Note: When the input graph is a TrivialGraph, the returned
|
||||
lattice is actually an acceptor.
|
||||
"""
|
||||
assert encoder_out.ndim == 3
|
||||
|
||||
context_size = model.decoder.context_size
|
||||
@ -104,9 +254,7 @@ def fast_beam_search(
|
||||
decoding_streams.terminate_and_flush_to_streams()
|
||||
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
|
||||
|
||||
best_path = one_best_decoding(lattice)
|
||||
hyps = get_texts(best_path)
|
||||
return hyps
|
||||
return lattice
|
||||
|
||||
|
||||
def greedy_search(
|
||||
@ -131,6 +279,7 @@ def greedy_search(
|
||||
|
||||
blank_id = model.decoder.blank_id
|
||||
context_size = model.decoder.context_size
|
||||
unk_id = getattr(model, "unk_id", blank_id)
|
||||
|
||||
device = model.device
|
||||
|
||||
@ -171,7 +320,7 @@ def greedy_search(
|
||||
# logits is (1, 1, 1, vocab_size)
|
||||
|
||||
y = logits.argmax().item()
|
||||
if y != blank_id:
|
||||
if y not in (blank_id, unk_id):
|
||||
hyp.append(y)
|
||||
decoder_input = torch.tensor(
|
||||
[hyp[-context_size:]], device=device
|
||||
@ -212,6 +361,7 @@ def greedy_search_batch(
|
||||
T = encoder_out.size(1)
|
||||
|
||||
blank_id = model.decoder.blank_id
|
||||
unk_id = getattr(model, "unk_id", blank_id)
|
||||
context_size = model.decoder.context_size
|
||||
|
||||
hyps = [[blank_id] * context_size for _ in range(batch_size)]
|
||||
@ -240,7 +390,7 @@ def greedy_search_batch(
|
||||
y = logits.argmax(dim=1).tolist()
|
||||
emitted = False
|
||||
for i, v in enumerate(y):
|
||||
if v != blank_id:
|
||||
if v not in (blank_id, unk_id):
|
||||
hyps[i].append(v)
|
||||
emitted = True
|
||||
if emitted:
|
||||
@ -433,6 +583,7 @@ def modified_beam_search(
|
||||
T = encoder_out.size(1)
|
||||
|
||||
blank_id = model.decoder.blank_id
|
||||
unk_id = getattr(model, "unk_id", blank_id)
|
||||
context_size = model.decoder.context_size
|
||||
device = model.device
|
||||
B = [HypothesisList() for _ in range(batch_size)]
|
||||
@ -515,7 +666,7 @@ def modified_beam_search(
|
||||
|
||||
new_ys = hyp.ys[:]
|
||||
new_token = topk_token_indexes[k]
|
||||
if new_token != blank_id:
|
||||
if new_token not in (blank_id, unk_id):
|
||||
new_ys.append(new_token)
|
||||
|
||||
new_log_prob = topk_log_probs[k]
|
||||
@ -556,6 +707,7 @@ def _deprecated_modified_beam_search(
|
||||
# support only batch_size == 1 for now
|
||||
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
||||
blank_id = model.decoder.blank_id
|
||||
unk_id = getattr(model, "unk_id", blank_id)
|
||||
context_size = model.decoder.context_size
|
||||
|
||||
device = model.device
|
||||
@ -626,7 +778,7 @@ def _deprecated_modified_beam_search(
|
||||
hyp = A[topk_hyp_indexes[i]]
|
||||
new_ys = hyp.ys[:]
|
||||
new_token = topk_token_indexes[i]
|
||||
if new_token != blank_id:
|
||||
if new_token not in (blank_id, unk_id):
|
||||
new_ys.append(new_token)
|
||||
new_log_prob = topk_log_probs[i]
|
||||
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
||||
@ -663,6 +815,7 @@ def beam_search(
|
||||
# support only batch_size == 1 for now
|
||||
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
||||
blank_id = model.decoder.blank_id
|
||||
unk_id = getattr(model, "unk_id", blank_id)
|
||||
context_size = model.decoder.context_size
|
||||
|
||||
device = model.device
|
||||
@ -748,7 +901,7 @@ def beam_search(
|
||||
# Second, process other non-blank labels
|
||||
values, indices = log_prob.topk(beam + 1)
|
||||
for i, v in zip(indices.tolist(), values.tolist()):
|
||||
if i == blank_id:
|
||||
if i in (blank_id, unk_id):
|
||||
continue
|
||||
new_ys = y_star.ys + [i]
|
||||
new_log_prob = y_star.log_prob + v
|
||||
|
@ -69,7 +69,8 @@ import torch.nn as nn
|
||||
from asr_datamodule import AsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search,
|
||||
fast_beam_search_nbest_oracle,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
@ -145,6 +146,7 @@ def get_parser():
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_nbest_oracle
|
||||
""",
|
||||
)
|
||||
|
||||
@ -164,7 +166,8 @@ def get_parser():
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search""",
|
||||
Used only when --decoding-method is
|
||||
fast_beam_search or fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -172,7 +175,7 @@ def get_parser():
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
fast_beam_search or fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -180,7 +183,7 @@ def get_parser():
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
fast_beam_search or fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -198,6 +201,23 @@ def get_parser():
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=100,
|
||||
help="""Number of paths for computed nbest oracle WER
|
||||
when the decoding method is fast_beam_search_nbest_oracle.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""Scale applied to lattice scores when computing nbest paths.
|
||||
Used only when the decoding_method is fast_beam_search_nbest_oracle.
|
||||
""",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
@ -231,7 +251,8 @@ def decode_one_batch(
|
||||
for the format of the `batch`.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
|
||||
only when --decoding_method is fast_beam_search.
|
||||
only when --decoding_method is
|
||||
fast_beam_search or fast_beam_search_nbest_oracle.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
@ -252,7 +273,7 @@ def decode_one_batch(
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search(
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
@ -263,6 +284,21 @@ def decode_one_batch(
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=sp.encode(supervisions["text"]),
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
@ -316,6 +352,16 @@ def decode_one_batch(
|
||||
f"max_states_{params.max_states}"
|
||||
): hyps
|
||||
}
|
||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||
return {
|
||||
(
|
||||
f"beam_{params.beam}_"
|
||||
f"max_contexts_{params.max_contexts}_"
|
||||
f"max_states_{params.max_states}_"
|
||||
f"num_paths_{params.num_paths}_"
|
||||
f"nbest_scale_{params.nbest_scale}"
|
||||
): hyps
|
||||
}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
|
||||
@ -450,15 +496,22 @@ def main():
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
"modified_beam_search",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
# Encode the search hyper-parameters in the result-file suffix so that runs
# with different settings do not overwrite each other's outputs.
if params.decoding_method == "fast_beam_search":
    params.suffix += f"-beam-{params.beam}"
    params.suffix += f"-max-contexts-{params.max_contexts}"
    params.suffix += f"-max-states-{params.max_states}"
elif params.decoding_method == "fast_beam_search_nbest_oracle":
    # This branch used to repeat the "fast_beam_search" test, which made it
    # unreachable: the oracle method fell through to the generic
    # "beam_search" branch and its suffix carried `beam_size` instead of
    # the parameters that actually control the oracle search.
    params.suffix += f"-beam-{params.beam}"
    params.suffix += f"-max-contexts-{params.max_contexts}"
    params.suffix += f"-max-states-{params.max_states}"
    params.suffix += f"-num-paths-{params.num_paths}"
    params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam_size}"
|
||||
else:
|
||||
@ -479,6 +532,7 @@ def main():
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.unk_id()
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
@ -506,8 +560,12 @@ def main():
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
model.unk_id = params.unk_id
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
if params.decoding_method in (
|
||||
"fast_beam_search",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
):
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
else:
|
||||
decoding_graph = None
|
||||
|
Loading…
x
Reference in New Issue
Block a user