mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)
Copy the files related to multi round nbest rescoring from k2 & snowfall
This commit is contained in:
parent cf8d76293d
commit cabe8b625b
264 icefall/nbest.py Normal file
@@ -0,0 +1,264 @@
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)

# This file implements the ideas proposed by Daniel Povey.
#
# See https://github.com/k2-fsa/snowfall/issues/232 for more details
#
import logging
from typing import List

import torch
import _k2
import k2

# Note: We use `utterance` and `sequence` interchangeably in the comments


class Nbest(object):
    '''
    An Nbest object contains two fields:

        (1) fsa, whose type is k2.Fsa
        (2) shape, whose type is k2.RaggedShape (alias to _k2.RaggedShape)

    The field `fsa` is an FsaVec containing a vector of **linear** FSAs.

    The field `shape` has two axes [utt][path]. `shape.dim0()` contains
    the number of utterances, which is also the number of rows in the
    supervision_segments. `shape.tot_size(1)` contains the number
    of paths, which is also the number of FSAs in `fsa`.
    '''

    def __init__(self, fsa: k2.Fsa, shape: _k2.RaggedShape) -> None:
        assert len(fsa.shape) == 3, f'fsa.shape: {fsa.shape}'
        assert shape.num_axes() == 2, f'num_axes: {shape.num_axes()}'

        assert fsa.shape[0] == shape.tot_size(1), \
            f'{fsa.shape[0]} vs {shape.tot_size(1)}'

        self.fsa = fsa
        self.shape = shape

    def __str__(self):
        s = 'Nbest('
        s += f'num_seqs:{self.shape.dim0()}, '
        s += f'num_fsas:{self.fsa.shape[0]})'
        return s
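
    # A minimal construction sketch (it mirrors test_nbest_constructor() in
    # test/test_nbest.py below); the FSA and the shape are illustrative only:
    #
    #   fsa = k2.Fsa.from_str('''
    #       0 1 -1 0.1
    #       1
    #   ''')
    #   fsa_vec = k2.create_fsa_vec([fsa, fsa, fsa])
    #   shape = k2.RaggedShape('[[x x] [x]]')  # utt 0: 2 paths, utt 1: 1 path
    #   nbest = Nbest(fsa_vec, shape)
    #   print(nbest)  # -> Nbest(num_seqs:2, num_fsas:3)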

    def intersect(self, lats: k2.Fsa) -> 'Nbest':
        '''Intersect this Nbest object with a lattice and get the 1-best
        path from the resulting FsaVec.

        Caution:
          We assume FSAs in `self.fsa` don't have epsilon self-loops.
          We also assume `self.fsa.labels` and `lats.labels` are token IDs.

        Args:
          lats:
            An FsaVec. It can be the return value of
            :func:`whole_lattice_rescoring`.
        Returns:
          Return a new Nbest. This new Nbest shares the same shape with
          `self`, while its `fsa` is the 1-best path from intersecting
          `self.fsa` and `lats`.
        '''
        assert self.fsa.device == lats.device, \
            f'{self.fsa.device} vs {lats.device}'
        assert len(lats.shape) == 3, f'{lats.shape}'
        assert lats.arcs.dim0() == self.shape.dim0(), \
            f'{lats.arcs.dim0()} vs {self.shape.dim0()}'

        lats = k2.arc_sort(lats)  # no-op if lats is already arc sorted

        fsas_with_epsilon_loops = k2.add_epsilon_self_loops(self.fsa)

        path_to_seq_map = self.shape.row_ids(1)

        ans_lats = k2.intersect_device(a_fsas=lats,
                                       b_fsas=fsas_with_epsilon_loops,
                                       b_to_a_map=path_to_seq_map,
                                       sorted_match_a=True)

        one_best = k2.shortest_path(ans_lats, use_double_scores=True)

        one_best = k2.remove_epsilon(one_best)

        return Nbest(fsa=one_best, shape=self.shape)
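
    # A usage sketch (illustrative; `nbest` and `rescored_lats` are assumed
    # to be an Nbest from generate_nbest_list() and an FsaVec from
    # whole_lattice_rescoring(), both defined below):
    #
    #   nbest = nbest.intersect(rescored_lats)
    #   # The shape is unchanged; each linear FSA now carries the scores of
    #   # its best alignment through the rescored lattice.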

    def total_scores(self) -> _k2.RaggedFloat:
        '''Get the total scores of the FSAs in this Nbest.

        Note:
          Since FSAs in Nbest are just linear FSAs, the log semiring and
          the tropical semiring produce the same total scores.

        Returns:
          Return a ragged tensor with two axes [utt][path_scores].
        '''
        scores = self.fsa.get_tot_scores(use_double_scores=True,
                                         log_semiring=False)
        # We use single precision here since we only wrap k2.RaggedFloat.
        # If k2.RaggedDouble is wrapped, we can use double precision here.
        return _k2.RaggedFloat(self.shape, scores.float())
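
    # For instance (values illustrative), for an Nbest whose shape is
    # '[[x x] [x]]' and whose three linear FSAs have score sums 0.3, 0.1
    # and 0.2, total_scores() returns a ragged tensor [[0.3 0.1] [0.2]]:
    # one total score per path, grouped by utterance.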

    def top_k(self, k: int) -> 'Nbest':
        '''Get a subset of paths in the Nbest. The resulting Nbest is
        regular in that each sequence (i.e., utterance) has the same
        number of paths (k).

        We select the top-k paths according to the total_scores of each
        path. If an utterance has fewer than k paths, then its last path,
        after sorting by tot_scores in descending order, is repeated so
        that each utterance has exactly k paths.

        Args:
          k:
            Number of paths in each utterance.
        Returns:
          Return a new Nbest with a regular shape.
        '''
        ragged_scores = self.total_scores()

        # indexes contains idx01's for self.shape
        # ragged_scores.values()[indexes] is sorted
        indexes = k2.ragged.sort_sublist(ragged_scores,
                                         descending=True,
                                         need_new2old_indexes=True)

        ragged_indexes = k2.RaggedInt(self.shape, indexes)

        padded_indexes = k2.ragged.pad(ragged_indexes,
                                       mode='replicate',
                                       value=-1)
        assert torch.ge(padded_indexes, 0).all(), \
            'Some utterances contain empty ' \
            f'n-best: {self.shape.row_splits(1)}'

        # Select the idx01's of the top-k paths of each utterance
        top_k_indexes = padded_indexes[:, :k].flatten().contiguous()

        top_k_fsas = k2.index_fsa(self.fsa, top_k_indexes)

        top_k_shape = k2.ragged.regular_ragged_shape(dim0=self.shape.dim0(),
                                                     dim1=k)
        return Nbest(top_k_fsas, top_k_shape)
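
    # A worked example (taken from test_top_k() in test/test_nbest.py below):
    # given path scores grouped as [ [3 0] [1 5 4] [2 8 1 9 6] ], top_k(2)
    # keeps the paths scoring [ [3 0] [5 4] [9 8] ] and returns the regular
    # shape '[ [x x] [x x] [x x] ]'; top_k(4) pads short utterances by
    # repeating their last sorted path, e.g. scores [3 0 0 0] for the first.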


def whole_lattice_rescoring(lats: k2.Fsa,
                            G_with_epsilon_loops: k2.Fsa) -> k2.Fsa:
    '''Rescore the 1st pass lattice with an LM.

    In general, the G in the HLG used to obtain `lats` is a 3-gram LM.
    This function replaces the 3-gram LM in `lats` with a 4-gram LM.

    Args:
      lats:
        The decoding lattice from the 1st pass. We assume it is the result
        of intersecting HLG with the network output.
      G_with_epsilon_loops:
        An LM. It is usually a 4-gram LM with epsilon self-loops.
        It should be arc sorted.
    Returns:
      Return a new lattice rescored with the given G.
    '''
    assert len(lats.shape) == 3, f'{lats.shape}'
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None), \
        f'{G_with_epsilon_loops.shape}'

    device = lats.device
    lats.scores = lats.scores - lats.lm_scores
    # Now lats contains only acoustic scores

    # We will use lm_scores from the given G, so remove lats.lm_scores here
    del lats.lm_scores
    assert hasattr(lats, 'lm_scores') is False

    # inverted_lats has word IDs as labels.
    # Its aux_labels are token IDs, which is a ragged tensor k2.RaggedInt
    # if lats.aux_labels is a ragged tensor.
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]

    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)

    while True:
        try:
            rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                                 inverted_lats,
                                                 b_to_a_map,
                                                 sorted_match_a=True)
            break
        except RuntimeError as e:
            logging.info(f'Caught exception:\n{e}\n')
            # Usually, this is an OOM exception. We reduce
            # the size of the lattice and redo k2.intersect_device()

            # NOTE(fangjun): The choice of the threshold 1e-5 is arbitrary
            # here to avoid OOM. We may need to fine-tune it.
            logging.info(f'num_arcs before: {inverted_lats.num_arcs}')
            inverted_lats = k2.prune_on_arc_post(inverted_lats, 1e-5, True)
            logging.info(f'num_arcs after: {inverted_lats.num_arcs}')

    rescoring_lats = k2.top_sort(k2.connect(rescoring_lats))

    # inv_rescoring_lats has token IDs as labels
    # and word IDs as aux_labels.
    inv_rescoring_lats = k2.invert(rescoring_lats)
    return inv_rescoring_lats
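
# A usage sketch (illustrative; `lats` and `G_4gram` are assumptions: a
# first-pass HLG decoding lattice carrying an `lm_scores` attribute, and an
# arc-sorted 4-gram G with epsilon self-loops):
#
#   rescored_lats = whole_lattice_rescoring(lats, G_4gram)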


def generate_nbest_list(lats: k2.Fsa, num_paths: int) -> Nbest:
    '''Generate an n-best list from a lattice.

    Args:
      lats:
        The decoding lattice from the first pass after LM rescoring.
        lats is an FsaVec. It can be the return value of
        :func:`whole_lattice_rescoring`.
      num_paths:
        Size of n for the n-best list. CAUTION: After removing paths
        that represent the same token sequences, the number of paths
        in different sequences may not be equal.
    Return:
      Return an Nbest object. Note the returned FSAs don't have epsilon
      self-loops.
    '''
    assert len(lats.shape) == 3

    # CAUTION: We use `phones` instead of `tokens` here because
    # :func:`compile_HLG` uses `phones`
    #
    # Note: compile_HLG is from k2-fsa/snowfall
    assert hasattr(lats, 'phones')

    assert not hasattr(lats, 'tokens')
    lats.tokens = lats.phones
    # we use tokens instead of phones in the following code

    # First, extract `num_paths` paths for each sequence.
    # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)

    # token_seqs is a k2.RaggedInt sharing the same shape as `paths`
    # but it contains token IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.
    # Its axes are [seq][path][token_id]
    token_seqs = k2.index(lats.tokens, paths)

    # Remove epsilons (0s) and -1 from token_seqs
    token_seqs = k2.ragged.remove_values_leq(token_seqs, 0)

    # unique_token_seqs is still a k2.RaggedInt with axes
    # [seq][path][token_id], but the number of paths in each sequence
    # may be different.
    unique_token_seqs, _, _ = k2.ragged.unique_sequences(
        token_seqs, need_num_repeats=False, need_new2old_indexes=False)

    seq_to_path_shape = k2.ragged.get_layer(unique_token_seqs.shape(), 0)

    # Remove the seq axis.
    # Now unique_token_seqs has only two axes [path][token_id]
    unique_token_seqs = k2.ragged.remove_axis(unique_token_seqs, 0)

    token_fsas = k2.linear_fsa(unique_token_seqs)

    return Nbest(fsa=token_fsas, shape=seq_to_path_shape)
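
# An end-to-end sketch of how the pieces in this file compose for
# multi-round n-best rescoring (illustrative only; `lats` and `G_4gram`
# are assumptions, as in the sketch after whole_lattice_rescoring() above):
#
#   nbest = generate_nbest_list(lats, num_paths=100)
#   rescored_lats = whole_lattice_rescoring(lats, G_4gram)
#   nbest = nbest.intersect(rescored_lats)
#   nbest = nbest.top_k(10)  # a regular 10-best list per utterance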
@@ -5,6 +5,7 @@ import subprocess
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from nbest import Nbest
from pathlib import Path
from typing import Dict, Iterable, List, TextIO, Tuple, Union

@@ -381,3 +382,75 @@ def write_error_stats(

        print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
    return float(tot_err_rate)


def get_best_matching_stats(keys: Nbest, queries: Nbest,
                            max_order: int) -> torch.Tensor:
    '''Get best matching stats on query positions.

    Args:
      keys:
        The nbest after doing second pass rescoring.
      queries:
        Another nbest before doing second pass rescoring.
      max_order:
        The maximum n-gram order ever returned by
        `k2.get_best_matching_stats`.

    Returns:
      A tensor with shape [queries.fsa.num_elements, 5]; each row contains
      the stats (init_score, mean, var, counts_out, ngram_order) of the
      token in the corresponding position in queries.
    '''
    assert keys.shape.dim0() == queries.shape.dim0(), \
        'The number of utterances in keys and queries should be equal: ' \
        f'{keys.shape.dim0()} vs {queries.shape.dim0()}'

    # keys_tokens_shape [utt][path][token]
    keys_tokens_shape = k2.ragged.compose_ragged_shapes(
        keys.shape, k2.ragged.remove_axis(keys.fsa.arcs.shape(), 1))
    # queries_tokens_shape [utt][path][token]
    queries_tokens_shape = k2.ragged.compose_ragged_shapes(
        queries.shape, k2.ragged.remove_axis(queries.fsa.arcs.shape(), 1))

    keys_tokens = k2.RaggedInt(keys_tokens_shape, keys.fsa.labels.clone())
    queries_tokens = k2.RaggedInt(queries_tokens_shape,
                                  queries.fsa.labels.clone())
    # tokens shape [utt][path][token]
    tokens = k2.ragged.cat([keys_tokens, queries_tokens], axis=1)

    keys_token_num = keys.fsa.labels.size()[0]
    queries_tokens_num = queries.fsa.labels.size()[0]
    # counts on key positions are ones
    keys_counts = k2.RaggedInt(keys_tokens_shape,
                               torch.ones(keys_token_num,
                                          dtype=torch.int32))
    # counts on query positions are zeros
    queries_counts = k2.RaggedInt(queries_tokens_shape,
                                  torch.zeros(queries_tokens_num,
                                              dtype=torch.int32))
    counts = k2.ragged.cat([keys_counts, queries_counts], axis=1).values()

    # scores on key positions are the scores inherited from the nbest path
    keys_scores = k2.RaggedFloat(keys_tokens_shape, keys.fsa.scores.clone())
    # scores on query positions MUST be zeros
    queries_scores = k2.RaggedFloat(queries_tokens_shape,
                                    torch.zeros(queries_tokens_num,
                                                dtype=torch.float32))
    scores = k2.ragged.cat([keys_scores, queries_scores], axis=1).values()

    # we didn't remove -1 labels before
    min_token = -1
    eos = -1
    max_token = torch.max(torch.max(keys.fsa.labels),
                          torch.max(queries.fsa.labels))
    mean, var, counts_out, ngram = k2.get_best_matching_stats(
        tokens, scores, counts, eos, min_token, max_token, max_order)

    queries_init_scores = queries.fsa.scores.clone()
    # only return the stats on query positions
    masking = counts == 0
    # shape [queries_tokens_num, 5]
    return torch.transpose(torch.stack((queries_init_scores, mean[masking],
                                        var[masking], counts_out[masking],
                                        ngram[masking])), 0, 1)
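
# A usage sketch (illustrative; `nbest_before` and `nbest_after` are
# assumptions: the same n-best list before and after second-pass rescoring):
#
#   stats = get_best_matching_stats(keys=nbest_after,
#                                   queries=nbest_before,
#                                   max_order=3)
#   # stats[i] is (init_score, mean, var, counts_out, ngram_order) for the
#   # i-th token position in queries.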
110 test/test_nbest.py Normal file
@@ -0,0 +1,110 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# To run this single test, use
#
#   ctest --verbose -R nbest_test_py

import unittest

import k2
import torch

from icefall.nbest import Nbest


class TestNbest(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.devices = [torch.device('cpu')]
        if torch.cuda.is_available() and k2.with_cuda:
            cls.devices.append(torch.device('cuda', 0))
            if torch.cuda.device_count() > 1:
                torch.cuda.set_device(1)
                cls.devices.append(torch.device('cuda', 1))

    def test_nbest_constructor(self):
        fsa = k2.Fsa.from_str('''
            0 1 -1 0.1
            1
        ''')

        fsa_vec = k2.create_fsa_vec([fsa, fsa, fsa])
        shape = k2.RaggedShape('[[x x] [x]]')
        Nbest(fsa_vec, shape)

    def test_top_k(self):
        fsa0 = k2.Fsa.from_str('''
            0 1 -1 0
            1
        ''')
        fsas = [fsa0.clone() for i in range(10)]
        fsa_vec = k2.create_fsa_vec(fsas)
        fsa_vec.scores = torch.tensor([3, 0, 1, 5, 4, 2, 8, 1, 9, 6],
                                      dtype=torch.float)
        # indexes: 0 1    2 3 4    5 6 7 8 9
        # scores: [ [3 0] [1 5 4] [2 8 1 9 6] ]
        shape = k2.RaggedShape('[ [x x] [x x x] [x x x x x] ]')
        nbest = Nbest(fsa_vec, shape)

        # top_k: k is 1
        nbest1 = nbest.top_k(1)
        expected_fsa = k2.create_fsa_vec([fsa_vec[0], fsa_vec[3], fsa_vec[8]])
        assert str(nbest1.fsa) == str(expected_fsa)

        expected_shape = k2.RaggedShape('[ [x] [x] [x] ]')
        assert nbest1.shape == expected_shape

        # top_k: k is 2
        nbest2 = nbest.top_k(2)
        expected_fsa = k2.create_fsa_vec([
            fsa_vec[0], fsa_vec[1], fsa_vec[3], fsa_vec[4], fsa_vec[8],
            fsa_vec[6]
        ])
        assert str(nbest2.fsa) == str(expected_fsa)

        expected_shape = k2.RaggedShape('[ [x x] [x x] [x x] ]')
        assert nbest2.shape == expected_shape

        # top_k: k is 3
        nbest3 = nbest.top_k(3)
        expected_fsa = k2.create_fsa_vec([
            fsa_vec[0], fsa_vec[1], fsa_vec[1], fsa_vec[3], fsa_vec[4],
            fsa_vec[2], fsa_vec[8], fsa_vec[6], fsa_vec[9]
        ])
        assert str(nbest3.fsa) == str(expected_fsa)

        expected_shape = k2.RaggedShape('[ [x x x] [x x x] [x x x] ]')
        assert nbest3.shape == expected_shape

        # top_k: k is 4
        nbest4 = nbest.top_k(4)
        expected_fsa = k2.create_fsa_vec([
            fsa_vec[0], fsa_vec[1], fsa_vec[1], fsa_vec[1], fsa_vec[3],
            fsa_vec[4], fsa_vec[2], fsa_vec[2], fsa_vec[8], fsa_vec[6],
            fsa_vec[9], fsa_vec[5]
        ])
        assert str(nbest4.fsa) == str(expected_fsa)

        expected_shape = k2.RaggedShape('[ [x x x x] [x x x x] [x x x x] ]')
        assert nbest4.shape == expected_shape


if __name__ == '__main__':
    unittest.main()