mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00

Refactor decoding code.

This commit is contained in:
parent 00f8371f37
commit 6f9fe5b906
1  .github/workflows/test.yml  vendored
@@ -55,6 +55,7 @@ jobs:
           git clone --depth 1 https://github.com/lhotse-speech/lhotse
           cd lhotse
           sed -i.bak "/torch/d" requirements.txt
           pip install -r ./requirements.txt

       - name: Run tests
@@ -13,6 +13,7 @@ from model import TdnnLstm

 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.dataset.librispeech import LibriSpeechAsrDataModule
+from icefall.decode import get_lattice, nbest_decoding, one_best_decoding
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -48,7 +49,7 @@ def get_parser():
 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
-            "exp_dir": Path("tdnn_lstm_ctc/exp3/"),
+            "exp_dir": Path("tdnn_lstm_ctc/exp/"),
             "lang_dir": Path("data/lang"),
             "feature_dim": 80,
             "subsampling_factor": 3,
@@ -56,6 +57,9 @@ def get_params() -> AttributeDict:
             "output_beam": 8,
             "min_active_states": 30,
             "max_active_states": 10000,
+            "use_double_scores": True,
+            "method": "1best",  # [1best, nbest]
+            "num_paths": 30,  # used when method is nbest
         }
     )
     return params
@@ -100,20 +104,28 @@ def decode_one_batch(
         1,
     ).to(torch.int32)

-    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
-
-    lattices = k2.intersect_dense_pruned(
-        HLG,
-        dense_fsa_vec,
+    lattice = get_lattice(
+        nnet_output=nnet_output,
+        HLG=HLG,
+        supervision_segments=supervision_segments,
         search_beam=params.search_beam,
         output_beam=params.output_beam,
         min_active_states=params.min_active_states,
         max_active_states=params.max_active_states,
     )

-    best_paths = k2.shortest_path(lattices, use_double_scores=True)
+    if params.method == "1best":
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+    else:
+        best_path = nbest_decoding(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            use_double_scores=params.use_double_scores,
+        )

-    hyps = get_texts(best_paths)
+    hyps = get_texts(best_path)
     hyps = [[lexicon.words[i] for i in ids] for ids in hyps]

     texts = supervisions["text"]
245  icefall/decode.py  Normal file

@@ -0,0 +1,245 @@
import k2
import torch


def _intersect_device(
    a_fsas: k2.Fsa,
    b_fsas: k2.Fsa,
    b_to_a_map: torch.Tensor,
    sorted_match_a: bool,
    batch_size: int = 50,
):
    """This is a wrapper of k2.intersect_device and its purpose is to split
    b_fsas into several batches and process each batch separately to avoid
    a CUDA OOM error.

    The arguments and return value of this function are the same as those
    of k2.intersect_device.
    """
    num_fsas = b_fsas.shape[0]
    if num_fsas <= batch_size:
        return k2.intersect_device(
            a_fsas, b_fsas, b_to_a_map=b_to_a_map, sorted_match_a=sorted_match_a
        )

    num_batches = (num_fsas + batch_size - 1) // batch_size
    splits = []
    for i in range(num_batches):
        start = i * batch_size
        end = min(start + batch_size, num_fsas)
        splits.append((start, end))

    ans = []
    for start, end in splits:
        indexes = torch.arange(start, end).to(b_to_a_map)

        fsas = k2.index(b_fsas, indexes)
        b_to_a = k2.index(b_to_a_map, indexes)
        path_lattice = k2.intersect_device(
            a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a
        )
        ans.append(path_lattice)

    return k2.cat(ans)


def get_lattice(
    nnet_output: torch.Tensor,
    HLG: k2.Fsa,
    supervision_segments: torch.Tensor,
    search_beam: float,
    output_beam: float,
    min_active_states: int,
    max_active_states: int,
):
    """Get the decoding lattice from a decoding graph and neural
    network output.

    Args:
      nnet_output:
        It is the output of a neural model of shape `[N, T, C]`.
      HLG:
        An Fsa, the decoding graph. See also `compile_HLG.py`.
      supervision_segments:
        A 2-D **CPU** tensor of dtype `torch.int32` with 3 columns.
        Each row contains information for a supervision segment. Column 0
        is the `sequence_index` indicating which sequence this segment
        comes from; column 1 specifies the `start_frame` of this segment
        within the sequence; column 2 contains the `duration` of this
        segment.
      search_beam:
        Decoding beam, e.g. 20. Smaller is faster, larger is more exact
        (less pruning). This is the default value; it may be modified by
        `min_active_states` and `max_active_states`.
      output_beam:
        Beam to prune output, similar to lattice-beam in Kaldi. Relative
        to best path of output.
      min_active_states:
        Minimum number of FSA states that are allowed to be active on any
        given frame for any given intersection/composition task. This is
        advisory, in that it will try not to have fewer than this number
        active. Set it to zero if there is no constraint.
      max_active_states:
        Maximum number of FSA states that are allowed to be active on any
        given frame for any given intersection/composition task. This is
        advisory, in that it will try not to exceed that but may not always
        succeed. You can use a very large number if no constraint is needed.
    Returns:
      A lattice containing the decoding result.
    """
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    lattice = k2.intersect_dense_pruned(
        HLG,
        dense_fsa_vec,
        search_beam=search_beam,
        output_beam=output_beam,
        min_active_states=min_active_states,
        max_active_states=max_active_states,
    )

    return lattice
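For orientation, here is a minimal sketch of how `get_lattice` might be called on a single utterance. The HLG checkpoint path, the random posteriors, and the concrete beam values (echoing `get_params()` in the decoding script above and the docstring's example of 20) are illustrative assumptions, not something this commit provides:

import k2
import torch

from icefall.decode import get_lattice

# Assumed location of a precompiled decoding graph (see compile_HLG.py);
# the file name here is hypothetical.
HLG = k2.Fsa.from_dict(torch.load("data/lang/HLG.pt"))

# Fake log-posteriors for one sequence: N=1, T=100 frames, C=500 outputs.
nnet_output = torch.randn(1, 100, 500).log_softmax(dim=-1)

# One row per supervision segment: [sequence_index, start_frame, duration].
supervision_segments = torch.tensor([[0, 0, 100]], dtype=torch.int32)

lattice = get_lattice(
    nnet_output=nnet_output,
    HLG=HLG,
    supervision_segments=supervision_segments,
    search_beam=20,
    output_beam=8,
    min_active_states=30,
    max_active_states=10000,
)
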
def one_best_decoding(
    lattice: k2.Fsa, use_double_scores: bool = True
) -> k2.Fsa:
    """Get the best path from a lattice.

    Args:
      lattice:
        The decoding lattice returned by :func:`get_lattice`.
      use_double_scores:
        True to use double precision floating point in the computation.
        False to use single precision.
    Returns:
      An FsaVec containing linear paths.
    """
    best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores)
    return best_path


def nbest_decoding(
    lattice: k2.Fsa, num_paths: int, use_double_scores: bool = True
):
    """It implements something like CTC prefix beam search using n-best lists.

    The basic idea is to first extract n-best paths from the given lattice,
    build word sequences from these paths, and compute the total scores
    of these sequences in the log-semiring. The one with the max score
    is used as the decoding output.

    Caution:
      Don't be confused by `best` in the name `n-best`. Paths are selected
      randomly, not by ranking their scores.

    Args:
      lattice:
        The decoding lattice, returned by :func:`get_lattice`.
      num_paths:
        It specifies the size `n` in n-best. Note: Paths are selected randomly
        and those containing identical word sequences are removed, and only
        one of them is kept.
      use_double_scores:
        True to use double precision floating point in the computation.
        False to use single precision.
    Returns:
      An FsaVec containing linear FSAs.
    """
    # First, extract `num_paths` paths for each sequence.
    # path is a k2.RaggedInt with axes [seq][path][arc_pos]
    path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)

    # word_seq is a k2.RaggedInt sharing the same shape as `path`
    # but it contains word IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.
    word_seq = k2.index(lattice.aux_labels, path)
    # Note: the above operation supports also the case when
    # lattice.aux_labels is a ragged tensor. In that case,
    # `remove_axis=True` is used inside the pybind11 binding code,
    # so the resulting `word_seq` still has 3 axes, like `path`.
    # The 3 axes are [seq][path][word_id]

    # Remove 0 (epsilon) and -1 from word_seq
    word_seq = k2.ragged.remove_values_leq(word_seq, 0)

    # Remove paths with identical word sequences.
    #
    # k2.ragged.unique_sequences will reorder paths within a seq.
    # `new2old` is a 1-D torch.Tensor mapping from the output path index
    # to the input path index.
    # new2old.numel() == unique_word_seq.tot_size(1)
    unique_word_seq, _, new2old = k2.ragged.unique_sequences(
        word_seq, need_num_repeats=False, need_new2old_indexes=True
    )
    # Note: unique_word_seq still has the same axes as word_seq

    seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0)

    # path_to_seq_map is a 1-D torch.Tensor.
    # path_to_seq_map[i] is the seq to which the i-th path belongs
    path_to_seq_map = seq_to_path_shape.row_ids(1)

    # Remove the seq axis.
    # Now unique_word_seq has only two axes [path][word]
    unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0)

    # word_fsa is an FsaVec with axes [path][state][arc]
    word_fsa = k2.linear_fsa(unique_word_seq)

    # Add epsilon self-loops since we will use
    # k2.intersect_device, which treats epsilon as a normal symbol.
    word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)

    # lattice has token IDs as labels and word IDs as aux_labels.
    # inv_lattice has word IDs as labels and token IDs as aux_labels
    inv_lattice = k2.invert(lattice)
    inv_lattice = k2.arc_sort(inv_lattice)

    path_lattice = _intersect_device(
        inv_lattice,
        word_fsa_with_epsilon_loops,
        b_to_a_map=path_to_seq_map,
        sorted_match_a=True,
    )
    # path_lattice has word IDs as labels and token IDs as aux_labels

    path_lattice = k2.top_sort(k2.connect(path_lattice))

    tot_scores = path_lattice.get_tot_scores(
        use_double_scores=use_double_scores, log_semiring=False
    )

    # RaggedFloat currently supports float32 only.
    # If Ragged<double> is wrapped, we can use k2.RaggedDouble here.
    ragged_tot_scores = k2.RaggedFloat(
        seq_to_path_shape, tot_scores.to(torch.float32)
    )

    argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)

    # Since we invoked `k2.ragged.unique_sequences`, which reorders
    # the index from `path`, we use `new2old` here to convert argmax_indexes
    # to indexes into `path`.
    #
    # Use k2.index here since argmax_indexes' dtype is torch.int32
    best_path_indexes = k2.index(new2old, argmax_indexes)

    path_2axes = k2.ragged.remove_axis(path, 0)

    # best_path is a k2.RaggedInt with 2 axes [path][arc_pos]
    best_path = k2.index(path_2axes, best_path_indexes)

    # labels is a k2.RaggedInt with 2 axes [path][token_id]
    # Note that it contains -1s.
    labels = k2.index(lattice.labels.contiguous(), best_path)

    labels = k2.ragged.remove_values_eq(labels, -1)

    # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so
    # aux_labels is also a k2.RaggedInt with 2 axes
    aux_labels = k2.index(lattice.aux_labels, best_path.values())

    best_path_fsa = k2.linear_fsa(labels)
    best_path_fsa.aux_labels = aux_labels
    return best_path_fsa
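Putting the new API together: the refactored `decode_one_batch` shown earlier dispatches on `params.method` and then maps the best path back to words. Below is a rough sketch of that flow, continuing from the `get_lattice` example above; the `method`, `num_paths`, and `use_double_scores` values mirror the entries this commit adds to `get_params()`, and the final word-mapping step is only described in comments because it needs the script's `lexicon` and `get_texts` helpers:

from icefall.decode import nbest_decoding, one_best_decoding

# `lattice` comes from the get_lattice() sketch above.
method = "1best"  # or "nbest"
num_paths = 30    # only used when method == "nbest"
use_double_scores = True

if method == "1best":
    # 1-best is just k2.shortest_path on the lattice.
    best_path = one_best_decoding(
        lattice=lattice, use_double_scores=use_double_scores
    )
else:
    # n-best samples num_paths random paths, removes duplicate word
    # sequences, rescores them against the lattice, and keeps the best.
    best_path = nbest_decoding(
        lattice=lattice,
        num_paths=num_paths,
        use_double_scores=use_double_scores,
    )

# In the decoding script, the linear paths are then turned into word IDs
# with get_texts(best_path) and mapped to words via lexicon.words, as in
# decode_one_batch() above.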
@@ -139,7 +139,7 @@ class AttributeDict(dict):


 def encode_supervisions(
-    supervisions: Dict[str, torch.Tensor], subsampling_factor: int
+    supervisions: dict, subsampling_factor: int
 ) -> Tuple[torch.Tensor, List[str]]:
     """
     Encodes Lhotse's ``batch["supervisions"]`` dict into a pair of torch Tensor,