Refactor decoding code.

Fangjun Kuang 2021-07-24 22:23:50 +08:00
parent 00f8371f37
commit 6f9fe5b906
4 changed files with 267 additions and 9 deletions


@@ -55,6 +55,7 @@ jobs:
git clone --depth 1 https://github.com/lhotse-speech/lhotse
cd lhotse
sed -i.bak "/torch/d" requirements.txt
pip install -r ./requirements.txt
- name: Run tests


@ -13,6 +13,7 @@ from model import TdnnLstm
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.decode import get_lattice, nbest_decoding, one_best_decoding
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
@@ -48,7 +49,7 @@ def get_parser():
def get_params() -> AttributeDict:
params = AttributeDict(
{
"exp_dir": Path("tdnn_lstm_ctc/exp3/"),
"exp_dir": Path("tdnn_lstm_ctc/exp/"),
"lang_dir": Path("data/lang"),
"feature_dim": 80,
"subsampling_factor": 3,
@@ -56,6 +57,9 @@ def get_params() -> AttributeDict:
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
"method": "1best", # [1best, nbest]
"num_paths": 30, # used when method is nbest
}
)
return params
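The two new keys drive the decoding dispatch below. A minimal sketch of switching to n-best decoding at run time (assuming `AttributeDict` supports attribute-style access, as its use elsewhere in this diff suggests):

    params = get_params()
    params.method = "nbest"   # selects nbest_decoding instead of one_best_decoding
    params.num_paths = 100    # only consulted when method == "nbest"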
@@ -100,20 +104,28 @@ def decode_one_batch(
1,
).to(torch.int32)
dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
lattices = k2.intersect_dense_pruned(
HLG,
dense_fsa_vec,
lattice = get_lattice(
nnet_output=nnet_output,
HLG=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
)
best_paths = k2.shortest_path(lattices, use_double_scores=True)
if params.method == "1best":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
else:
best_path = nbest_decoding(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
)
hyps = get_texts(best_paths)
hyps = get_texts(best_path)
hyps = [[lexicon.words[i] for i in ids] for ids in hyps]
texts = supervisions["text"]

icefall/decode.py (new file, 245 lines)

@@ -0,0 +1,245 @@
import k2
import torch
def _intersect_device(
a_fsas: k2.Fsa,
b_fsas: k2.Fsa,
b_to_a_map: torch.Tensor,
sorted_match_a: bool,
batch_size: int = 50,
):
"""This is a wrapper of k2.intersect_device and its purpose is to split
b_fsas into several batches and process each batch separately to avoid
CUDA OOM error.
The arguments and return value of this function are the same as
k2.intersect_device.
"""
num_fsas = b_fsas.shape[0]
if num_fsas <= batch_size:
return k2.intersect_device(
a_fsas, b_fsas, b_to_a_map=b_to_a_map, sorted_match_a=sorted_match_a
)
num_batches = (num_fsas + batch_size - 1) // batch_size
splits = []
for i in range(num_batches):
start = i * batch_size
end = min(start + batch_size, num_fsas)
splits.append((start, end))
ans = []
for start, end in splits:
indexes = torch.arange(start, end).to(b_to_a_map)
fsas = k2.index(b_fsas, indexes)
b_to_a = k2.index(b_to_a_map, indexes)
path_lattice = k2.intersect_device(
a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a
)
ans.append(path_lattice)
return k2.cat(ans)
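The splitting logic above is plain ceiling division over FSA indexes. A self-contained sketch (pure Python, no k2 required) checking the split boundaries for a hypothetical 123 FSAs with the default batch_size of 50:

    num_fsas, batch_size = 123, 50
    num_batches = (num_fsas + batch_size - 1) // batch_size  # ceil(123 / 50) == 3
    splits = [
        (i * batch_size, min((i + 1) * batch_size, num_fsas))
        for i in range(num_batches)
    ]
    assert splits == [(0, 50), (50, 100), (100, 123)]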
def get_lattice(
nnet_output: torch.Tensor,
HLG: k2.Fsa,
supervision_segments: torch.Tensor,
search_beam: float,
output_beam: float,
min_active_states: int,
max_active_states: int,
):
"""Get the decoding lattice from a decoding graph and neural
network output.
Args:
nnet_output:
It is the output of a neural model of shape `[N, T, C]`.
HLG:
An Fsa, the decoding graph. See also `compile_HLG.py`.
supervision_segments:
A 2-D **CPU** tensor of dtype `torch.int32` with 3 columns.
Each row contains information for a supervision segment. Column 0
is the `sequence_index` indicating which sequence this segment
comes from; column 1 specifies the `start_frame` of this segment
within the sequence; column 2 contains the `duration` of this
segment.
search_beam:
Decoding beam, e.g. 20. Smaller is faster, larger is more exact
(less pruning). This is the default value; it may be modified by
`min_active_states` and `max_active_states`.
output_beam:
Beam to prune output, similar to lattice-beam in Kaldi. Relative
to best path of output.
min_active_states:
Minimum number of FSA states that are allowed to be active on any given
frame for any given intersection/composition task. This is advisory,
in that it will try not to have fewer than this number active.
Set it to zero if there is no constraint.
max_active_states:
Maximum number of FSA states that are allowed to be active on any given
frame for any given intersection/composition task. This is advisory,
in that it will try not to exceed that but may not always succeed.
You can use a very large number if no constraint is needed.
Returns:
A lattice containing the decoding result.
"""
dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
lattice = k2.intersect_dense_pruned(
HLG,
dense_fsa_vec,
search_beam=search_beam,
output_beam=output_beam,
min_active_states=min_active_states,
max_active_states=max_active_states,
)
return lattice
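A hedged usage sketch; the shapes and beam values below are illustrative only, and `HLG` is assumed to have been built elsewhere (e.g. by compile_HLG.py):

    import torch

    # [N, T, C] log-probs from the acoustic model; N=2 sequences here
    nnet_output = torch.randn(2, 100, 500).log_softmax(dim=-1)
    # one row per supervision: (sequence_index, start_frame, duration),
    # sorted by decreasing duration as k2.DenseFsaVec expects
    supervision_segments = torch.tensor(
        [[0, 0, 100], [1, 0, 80]], dtype=torch.int32
    )
    lattice = get_lattice(
        nnet_output=nnet_output,
        HLG=HLG,  # decoding graph, loaded/compiled elsewhere
        supervision_segments=supervision_segments,
        search_beam=20.0,
        output_beam=8.0,
        min_active_states=30,
        max_active_states=10000,
    )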
def one_best_decoding(
lattice: k2.Fsa, use_double_scores: bool = True
) -> k2.Fsa:
"""Get the best path from a lattice.
Args:
lattice:
The decoding lattice returned by :func:`get_lattice`.
use_double_scores:
True to use double precision floating point in the computation.
False to use single precision.
Return:
An FsaVec containing linear paths.
"""
best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores)
return best_path
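A short usage sketch mirroring the 1best branch in decode.py above (`get_texts` and `lexicon` come from icefall.utils and icefall.lexicon, per the imports in that file):

    best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
    hyps = get_texts(best_path)  # one list of word IDs per supervision
    hyps = [[lexicon.words[i] for i in ids] for ids in hyps]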
def nbest_decoding(
lattice: k2.Fsa, num_paths: int, use_double_scores: bool = True
):
"""It implements something like CTC prefix beam search using n-best lists.
The basic idea is to first extra n-best paths from the given lattice,
build a word seqs from these paths, and compute the total scores
of these sequences in the log-semiring. The one with the max score
is used as the decoding output.
Caution:
Don't be confused by `best` in the name `n-best`. Paths are selected
randomly, not by ranking their scores.
Args:
lattice:
The decoding lattice, returned by :func:`get_lattice`.
num_paths:
It specifies the size `n` in n-best. Note: Paths are selected randomly
and those containing identical word sequences are removed and only one
of them is kept.
use_double_scores:
True to use double precision floating point in the computation.
False to use single precision.
Returns:
An FsaVec containing linear FSAs.
"""
# First, extract `num_paths` paths for each sequence.
# path is a k2.RaggedInt with axes [seq][path][arc_pos]
path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
# word_seq is a k2.RaggedInt sharing the same shape as `path`
# but it contains word IDs. Note that it also contains 0s and -1s.
# The last entry in each sublist is -1.
word_seq = k2.index(lattice.aux_labels, path)
# Note: the above operation supports also the case when
# lattice.aux_labels is a ragged tensor. In that case,
# `remove_axis=True` is used inside the pybind11 binding code,
# so the resulting `word_seq` still has 3 axes, like `path`.
# The 3 axes are [seq][path][word_id]
# Remove 0 (epsilon) and -1 from word_seq
word_seq = k2.ragged.remove_values_leq(word_seq, 0)
# Remove sequences with identical word sequences.
#
# k2.ragged.unique_sequences will reorder paths within a seq.
# `new2old` is a 1-D torch.Tensor mapping from the output path index
# to the input path index.
# new2old.numel() == unique_word_seq.tot_size(1)
unique_word_seq, _, new2old = k2.ragged.unique_sequences(
word_seq, need_num_repeats=False, need_new2old_indexes=True
)
# Note: unique_word_seq still has the same axes as word_seq
seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0)
# path_to_seq_map is a 1-D torch.Tensor.
# path_to_seq_map[i] is the seq to which the i-th path belongs
path_to_seq_map = seq_to_path_shape.row_ids(1)
# Remove the seq axis.
# Now unique_word_seq has only two axes [path][word]
unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0)
# word_fsa is an FsaVec with axes [path][state][arc]
word_fsa = k2.linear_fsa(unique_word_seq)
# add epsilon self loops since we will use
# k2.intersect_device, which treats epsilon as a normal symbol
word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)
# lattice has token IDs as labels and word IDs as aux_labels.
# inv_lattice has word IDs as labels and token IDs as aux_labels
inv_lattice = k2.invert(lattice)
inv_lattice = k2.arc_sort(inv_lattice)
path_lattice = _intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=path_to_seq_map,
sorted_match_a=True,
)
# path_lattice has word IDs as labels and token IDs as aux_labels
path_lattice = k2.top_sort(k2.connect(path_lattice))
tot_scores = path_lattice.get_tot_scores(
use_double_scores=use_double_scores, log_semiring=False
)
# RaggedFloat currently supports float32 only.
# If Ragged<double> is wrapped, we can use k2.RaggedDouble here
ragged_tot_scores = k2.RaggedFloat(
seq_to_path_shape, tot_scores.to(torch.float32)
)
argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)
# Since we invoked `k2.ragged.unique_sequences`, which reorders
# the index from `path`, we use `new2old` here to convert argmax_indexes
# to the indexes into `path`.
#
# Use k2.index here since argmax_indexes' dtype is torch.int32
best_path_indexes = k2.index(new2old, argmax_indexes)
path_2axes = k2.ragged.remove_axis(path, 0)
# best_path is a k2.RaggedInt with 2 axes [path][arc_pos]
best_path = k2.index(path_2axes, best_path_indexes)
# labels is a k2.RaggedInt with 2 axes [path][token_id]
# Note that it contains -1s.
labels = k2.index(lattice.labels.contiguous(), best_path)
labels = k2.ragged.remove_values_eq(labels, -1)
# lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so
# aux_labels is also a k2.RaggedInt with 2 axes
aux_labels = k2.index(lattice.aux_labels, best_path.values())
best_path_fsa = k2.linear_fsa(labels)
best_path_fsa.aux_labels = aux_labels
return best_path_fsa
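And the matching sketch for the nbest branch; num_paths=30 mirrors the default added to get_params in this commit:

    best_path = nbest_decoding(
        lattice=lattice, num_paths=30, use_double_scores=True
    )
    hyps = get_texts(best_path)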


@@ -139,7 +139,7 @@ class AttributeDict(dict):
def encode_supervisions(
supervisions: Dict[str, torch.Tensor], subsampling_factor: int
supervisions: dict, subsampling_factor: int
) -> Tuple[torch.Tensor, List[str]]:
"""
Encodes Lhotse's ``batch["supervisions"]`` dict into a pair of torch Tensors,