Refactor decode.py to make it more readable and more modular.

This commit is contained in:
Fangjun Kuang 2021-09-13 20:38:51 +08:00
parent 24656e9749
commit 9ddf23636e
2 changed files with 282 additions and 0 deletions

223
icefall/decode2.py Normal file
View File

@ -0,0 +1,223 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: This file is a refactor of decode.py
# We will delete decode.py and rename this file to decode.py
import k2
import torch
class Nbest(object):
"""
An Nbest object contains two fields:
(1) fsa. It is an FsaVec containing a vector of **linear** FSAs.
(2) shape. Its type is :class:`k2.RaggedShape`.
The field `shape` has two axes [utt][path]. `shape.dim0` contains
the number of utterances, which is also the number of rows in the
supervision_segments. `shape.tot_size(1)` contains the number
of paths, which is also the number of FSAs in `fsa`.
Caution:
Don't be confused by the name `Nbest`. The best in the name `Nbest`
has nothing to do with the `best scores`. The important part is
`N` in `Nbest`, not `best`.
"""
def __init__(self, fsa: k2.Fsa, shape: k2.RaggedShape) -> None:
assert len(fsa.shape) == 3, f"fsa.shape: {fsa.shape}"
assert shape.num_axes == 2, f"num_axes: {shape.num_axes}"
if fsa.shape[0] != shape.tot_size(1):
raise ValueError(
f"{fsa.shape[0]} vs {shape.tot_size(1)}\n"
"Number of FSAs in `fsa` does not match the given shape"
)
self.fsa = fsa
self.shape = shape
def __str__(self):
s = "Nbest("
s += f"num_seqs:{self.shape.dim0}, "
s += f"num_fsas:{self.fsa.shape[0]})"
return s
@staticmethod
def from_lattice(
lattice: k2.Fsa,
num_paths: int,
use_double_scores: bool = True,
scale: float = 0.5,
) -> "Nbest":
"""Construct an Nbest object by **sampling** `num_paths` from a lattice.
Args:
lattice:
An FsaVec with axes [utt][state][arc].
num_paths:
Number of paths to **sample** from the lattice
using :func:`k2.random_paths`.
use_double_scores:
True to use double precision in :func:`k2.random_paths`.
False to use single precision.
scale:
Scale `lattice.score` before passing it to :func:`k2.random_paths`.
A smaller value leads to more unique paths with the risk being not
to sample the path with the best score.
"""
saved_scores = lattice.scores.clone()
lattice.scores *= scale
# path is a ragged tensor with dtype torch.int32.
# It has three axes [utt][path][arc_pos
path = k2.random_paths(
lattice, num_paths=num_paths, use_double_scores=use_double_scores
)
lattice.scores = saved_scores
# word_seq is a k2.RaggedTensor sharing the same shape as `path`
# but it contains word IDs. Note that it also contains 0s and -1s.
# The last entry in each sublist is -1.
if isinstance(lattice.aux_labels, torch.Tensor):
word_seq = k2.ragged.index(lattice.aux_labels, path)
else:
word_seq = lattice.aux_labels.index(path, remove_axis=True)
# Each utterance has `num_paths` paths but some of them transduces
# to the same word sequence, so we need to remove repeated word
# sequences within an utterance. After removing repeats, each utterance
# contains different number of paths
#
# `new2old` is a 1-D torch.Tensor mapping from the output path index
# to the input path index.
_, _, new2old = word_seq.unique(
need_num_repeats=False, need_new2old_indexes=True
)
# kept_path is a ragged tensor with dtype torch.int32.
# It has axes [utt][path][arc_pos]
kept_path, _ = path.index(new2old, axis=1, need_value_indexes=False)
# utt_to_path_shape has axes [utt][path]
utt_to_path_shape = kept_path.shape.get_layer(0)
# Remove the utterance axis.
# Now kept_path has only two axes [path][arc_pos]
kept_path = kept_path.remove_axis(0)
# labels is a ragged tensor with 2 axes [path][token_id]
# Note that it contains -1s.
labels = k2.ragged.index(lattice.labels.contiguous(), kept_path)
# Remove -1 from labels as we will use it to construct a linear FSA
labels = labels.remove_values_eq(-1)
if isinstance(lattice.aux_labels, k2.RaggedTensor):
# lattice.aux_labels is a ragged tensor with dtype torch.int32.
# It has 2 axes [arc][word], so aux_labels is also a ragged tensor
# with 2 axes [arc][word]
aux_labels, _ = lattice.aux_labels.index(
indexes=kept_path.data, axis=0, need_value_indexes=False
)
else:
assert isinstance(lattice.aux_labels, torch.Tensor)
aux_labels = k2.index_select(lattice.aux_labels, kept_path.data)
# aux_labels is a 1-D torch.Tensor. It also contains -1 and 0.
fsa = k2.linear_fsa(labels)
fsa.aux_labels = aux_labels
return Nbest(fsa=fsa, shape=utt_to_path_shape)
def intersect(self, lattice: k2.Fsa, use_double_scores=True) -> "Nbest":
"""Intersect this Nbest object with a lattice and get 1-best
path from the resulting FsaVec.
Caution:
We assume FSAs in `self.fsa` don't have epsilon self-loops.
We also assume `self.fsa.labels` and `lattice.labels` are token IDs.
Args:
lattice:
An FsaVec with axes [utt][state][arc]
use_double_scores:
True to use double precision when computing shortest path.
False to use single precision.
Returns:
Return a new Nbest. This new Nbest shares the same shape with `self`,
while its `fsa` is the 1-best path from intersecting `self.fsa` and
`lattice`.
"""
assert (
self.fsa.device == lattice.device
), f"{self.fsa.device} vs {lattice.device}"
assert len(lattice.shape) == 3, f"{lattice.shape}"
assert (
lattice.arcs.dim0() == self.shape.dim0
), f"{lattice.arcs.dim0()} vs {self.shape.dim0}"
# We use a word fsa to intersect with k2.invert(lattice)
word_fsa = k2.invert(self.fsa)
# delete token IDs as it is not needed
del word_fsa.aux_labels
word_fsa.scores.zero_()
word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)
path_to_utt_map = self.shape.row_ids(1)
# lattice has token IDs as labels and word IDs as aux_labels.
# inv_lattice has word IDs as labels and token IDs as aux_labels
inv_lattice = k2.invert(lattice)
inv_lattice = k2.arc_sort(inv_lattice)
path_lattice = k2.intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=path_to_utt_map,
sorted_match_a=True,
)
# path_lattice has word IDs as labels and token IDs as aux_labels
path_lattice = k2.top_sort(k2.connect(path_lattice))
one_best = k2.shortest_path(
path_lattice, use_double_scores=use_double_scores
)
one_best = k2.remove_epsilon(one_best)
return Nbest(fsa=one_best, shape=self.shape)
def tot_scores(self) -> k2.RaggedTensor:
"""Get total scores of the FSAs in this Nbest.
Note:
Since FSAs in Nbest are just linear FSAs, log-semirng and tropical
semiring produce the same total scores.
Returns:
Return a ragged tensor with two axes [utt][path_scores].
"""
# Use single precision since there are only additions.
scores = self.fsa.get_tot_scores(
use_double_scores=False, log_semiring=False
)
return k2.RaggedTensor(self.shape, scores)

59
test/test_decode.py Normal file
View File

@ -0,0 +1,59 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
You can run this file in one of the two ways:
(1) cd icefall; pytest test/test_decode.py
(2) cd icefall; ./test/test_decode.py
"""
import k2
from icefall.decode2 import Nbest
def test_nbest_from_lattice():
s = """
0 1 1 10 0.1
0 1 5 10 0.11
0 1 2 20 0.2
1 2 3 30 0.3
1 2 4 40 0.4
2 3 -1 -1 0.5
3
"""
lattice = k2.Fsa.from_str(s, acceptor=False)
lattice = k2.Fsa.from_fsas([lattice, lattice])
nbest = Nbest.from_lattice(
lattice=lattice, num_paths=10, use_double_scores=True, scale=0.5
)
# each lattice has only 4 distinct paths that have different word sequences:
# 10->30
# 10->40
# 20->30
# 20->40
#
# So there should be only 4 paths for each lattice in the Nbest object
assert nbest.fsa.shape[0] == 4 * 2
assert nbest.shape.row_splits(1).tolist() == [0, 4, 8]
nbest2 = nbest.intersect(lattice)
tot_scores = nbest2.tot_scores()
argmax = tot_scores.argmax()
best_path = k2.index_fsa(nbest2.fsa, argmax)
print(best_path[0])