Copy the files related to multi-round nbest rescoring from k2 & snowfall

pkufool 2021-08-04 14:27:11 +08:00
parent cf8d76293d
commit cabe8b625b
3 changed files with 447 additions and 0 deletions

icefall/nbest.py (new file)
@@ -0,0 +1,264 @@
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
# This file implements the ideas proposed by Daniel Povey.
#
# See https://github.com/k2-fsa/snowfall/issues/232 for more details
#
import logging
from typing import List
import torch
import _k2
import k2
# Note: We use `utterance` and `sequence` interchangeably in the comments
class Nbest(object):
'''
An Nbest object contains two fields:
(1) fsa, its type is k2.Fsa
(2) shape, its type is k2.RaggedShape (alias to _k2.RaggedShape)
The field `fsa` is an FsaVec containing a vector of **linear** FSAs.
The field `shape` has two axes [utt][path]. `shape.dim0()` contains
the number of utterances, which is also the number of rows in the
supervision_segments. `shape.tot_size(1)` contains the number
of paths, which is also the number of FSAs in `fsa`.
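A minimal construction sketch (assuming `fsa0` is any linear FSA
whose labels are token IDs):

>>> fsa_vec = k2.create_fsa_vec([fsa0, fsa0, fsa0])
>>> shape = k2.RaggedShape('[ [x x] [x] ]')
>>> nbest = Nbest(fsa_vec, shape)  # 2 utterances, with 2 and 1 paths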
'''
def __init__(self, fsa: k2.Fsa, shape: _k2.RaggedShape) -> None:
assert len(fsa.shape) == 3, f'fsa.shape: {fsa.shape}'
assert shape.num_axes() == 2, f'num_axes: {shape.num_axes()}'
assert fsa.shape[0] == shape.tot_size(1), \
f'{fsa.shape[0]} vs {shape.tot_size(1)}'
self.fsa = fsa
self.shape = shape
def __str__(self):
s = 'Nbest('
s += f'num_seqs:{self.shape.dim0()}, '
s += f'num_fsas:{self.fsa.shape[0]})'
return s
def intersect(self, lats: k2.Fsa) -> 'Nbest':
'''Intersect this Nbest object with a lattice and get 1-best
path from the resulting FsaVec.
Caution:
We assume FSAs in `self.fsa` don't have epsilon self-loops.
We also assume `self.fsa.labels` and `lats.labels` are token IDs.
Args:
lats:
An FsaVec. It can be the return value of
:func:`whole_lattice_rescoring`.
Returns:
Return a new Nbest. This new Nbest shares the same shape with `self`,
while its `fsa` is the 1-best path from intersecting `self.fsa` and
`lats`.
'''
assert self.fsa.device == lats.device, \
f'{self.fsa.device} vs {lats.device}'
assert len(lats.shape) == 3, f'{lats.shape}'
assert lats.arcs.dim0() == self.shape.dim0(), \
f'{lats.arcs.dim0()} vs {self.shape.dim0()}'
lats = k2.arc_sort(lats) # no-op if lats is already arc sorted
fsas_with_epsilon_loops = k2.add_epsilon_self_loops(self.fsa)
path_to_seq_map = self.shape.row_ids(1)
ans_lats = k2.intersect_device(a_fsas=lats,
b_fsas=fsas_with_epsilon_loops,
b_to_a_map=path_to_seq_map,
sorted_match_a=True)
one_best = k2.shortest_path(ans_lats, use_double_scores=True)
one_best = k2.remove_epsilon(one_best)
return Nbest(fsa=one_best, shape=self.shape)
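# A hedged usage sketch of `intersect` (variable names are assumptions;
# `rescored_lats` is the return value of :func:`whole_lattice_rescoring`,
# whose labels are token IDs as this method requires):
#
#   rescored_lats = whole_lattice_rescoring(lats, G_with_epsilon_loops)
#   nbest = nbest.intersect(rescored_lats)
#
# The result keeps the same [utt][path] shape, but each linear FSA now
# carries the scores of its best matching path in the rescored lattice.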
def total_scores(self) -> _k2.RaggedFloat:
'''Get total scores of the FSAs in this Nbest.
Note:
Since FSAs in Nbest are just linear FSAs, the log semiring and the
tropical semiring produce the same total scores.
Returns:
Return a ragged tensor with two axes [utt][path_scores].
'''
scores = self.fsa.get_tot_scores(use_double_scores=True,
log_semiring=False)
# We use single precision here since we only wrap k2.RaggedFloat.
# If k2.RaggedDouble is wrapped, we can use double precision here.
return _k2.RaggedFloat(self.shape, scores.float())
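# A small sketch (names are assumptions): for the fixture used in
# test/test_nbest.py, where each path is a single arc,
#
#   ragged_scores = nbest.total_scores()
#
# returns a ragged tensor with axes [utt][path_scores], e.g.
# [ [3 0] [1 5 4] [2 8 1 9 6] ], one score per path.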
def top_k(self, k: int) -> 'Nbest':
'''Get a subset of paths in the Nbest. The resulting Nbest is regular
in that each sequence (i.e., utterance) has the same number of
paths (k).
We select the top-k paths according to the total_scores of each path.
If an utterance has fewer than k paths, then its last path, after sorting
by tot_scores in descending order, is repeated so that each utterance
has exactly k paths.
Args:
k:
Number of paths in each utterance.
Returns:
Return a new Nbest with a regular shape.
'''
ragged_scores = self.total_scores()
# indexes contains idx01's for self.shape
# ragged_scores.values()[indexes] is sorted
indexes = k2.ragged.sort_sublist(ragged_scores,
descending=True,
need_new2old_indexes=True)
ragged_indexes = k2.RaggedInt(self.shape, indexes)
padded_indexes = k2.ragged.pad(ragged_indexes,
mode='replicate',
value=-1)
assert torch.ge(padded_indexes, 0).all(), \
'Some utterances contain empty ' \
f'n-best: {self.shape.row_splits(1)}'
# Select the idx01's of top-k paths of each utterance
top_k_indexes = padded_indexes[:, :k].flatten().contiguous()
top_k_fsas = k2.index_fsa(self.fsa, top_k_indexes)
top_k_shape = k2.ragged.regular_ragged_shape(dim0=self.shape.dim0(),
dim1=k)
return Nbest(top_k_fsas, top_k_shape)
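# A hedged sketch of `top_k`, using the fixture from test/test_nbest.py:
# with per-utterance path scores [ [3 0] [1 5 4] [2 8 1 9 6] ],
#
#   nbest2 = nbest.top_k(2)
#
# keeps the paths scored (3, 0), (5, 4) and (9, 8); with k=3 the first
# utterance repeats its last path after sorting to stay regular.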
def whole_lattice_rescoring(lats: k2.Fsa, G_with_epsilon_loops: k2.Fsa) -> k2.Fsa:
'''Rescore the 1st pass lattice with an LM.
In general, the G in HLG used to obtain `lats` is a 3-gram LM.
This function replaces the 3-gram LM in `lats` with a 4-gram LM.
Args:
lats:
The decoding lattice from the 1st pass. We assume it is the result
of intersecting HLG with the network output.
G_with_epsilon_loops:
An LM. It is usually a 4-gram LM with epsilon self-loops.
It should be arc sorted.
Returns:
Return a new lattice rescored with a given G.
'''
assert len(lats.shape) == 3, f'{lats.shape}'
assert hasattr(lats, 'lm_scores')
assert G_with_epsilon_loops.shape == (1, None, None), \
f'{G_with_epsilon_loops.shape}'
device = lats.device
lats.scores = lats.scores - lats.lm_scores
# Now lats contains only acoustic scores
# We will use lm_scores from the given G, so remove lats.lm_scores here
del lats.lm_scores
assert not hasattr(lats, 'lm_scores')
# inverted_lats has word IDs as labels.
# Its aux_labels are token IDs, which is a ragged tensor k2.RaggedInt
# if lats.aux_labels is a ragged tensor
inverted_lats = k2.invert(lats)
num_seqs = lats.shape[0]
b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
while True:
try:
rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
inverted_lats,
b_to_a_map,
sorted_match_a=True)
break
except RuntimeError as e:
logging.info(f'Caught exception:\n{e}\n')
# Usually, this is an OOM exception. We reduce
# the size of the lattice and redo k2.intersect_device()
# NOTE(fangjun): The choice of the threshold 1e-5 is arbitrary here
# to avoid OOM. We may need to fine-tune it.
logging.info(f'num_arcs before: {inverted_lats.num_arcs}')
inverted_lats = k2.prune_on_arc_post(inverted_lats, 1e-5, True)
logging.info(f'num_arcs after: {inverted_lats.num_arcs}')
rescoring_lats = k2.top_sort(k2.connect(rescoring_lats))
# inv_rescoring_lats has token IDs as labels
# and word IDs as aux_labels.
inv_rescoring_lats = k2.invert(rescoring_lats)
return inv_rescoring_lats
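# A hedged usage sketch (preparing G is an assumption; the checkpoint
# path is hypothetical):
#
#   G = k2.Fsa.from_dict(torch.load('G_4_gram.pt'))
#   G = k2.add_epsilon_self_loops(G)
#   G = k2.arc_sort(G)
#   G = k2.create_fsa_vec([G])  # the assert above requires shape (1, None, None)
#   rescored_lats = whole_lattice_rescoring(lats, G)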
def generate_nbest_list(lats: k2.Fsa, num_paths: int) -> Nbest:
'''Generate an n-best list from a lattice.
Args:
lats:
The decoding lattice from the first pass after LM rescoring.
lats is an FsaVec. It can be the return value of
:func:`whole_lattice_rescoring`
num_paths:
Size of n for n-best list. CAUTION: After removing paths
that represent the same token sequences, the number of paths
in different sequences may not be equal.
Returns:
Return an Nbest object. Note the returned FSAs don't have epsilon
self-loops.
'''
assert len(lats.shape) == 3
# CAUTION: We use `phones` instead of `tokens` here because
# :func:`compile_HLG` uses `phones`
#
# Note: compile_HLG is from k2-fsa/snowfall
assert hasattr(lats, 'phones')
assert not hasattr(lats, 'tokens')
lats.tokens = lats.phones
# we use tokens instead of phones in the following code
# First, extract `num_paths` paths for each sequence.
# paths is a k2.RaggedInt with axes [seq][path][arc_pos]
paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)
# token_seqs is a k2.RaggedInt sharing the same shape as `paths`
# but it contains token IDs. Note that it also contains 0s and -1s.
# The last entry in each sublist is -1.
# Its axes are [seq][path][token_id]
token_seqs = k2.index(lats.tokens, paths)
# Remove epsilons (0s) and -1 from token_seqs
token_seqs = k2.ragged.remove_values_leq(token_seqs, 0)
# unique_token_seqs is still a k2.RaggedInt with axes [seq][path][token_id],
# but the number of paths in each sequence may be different.
unique_token_seqs, _, _ = k2.ragged.unique_sequences(
token_seqs, need_num_repeats=False, need_new2old_indexes=False)
seq_to_path_shape = k2.ragged.get_layer(unique_token_seqs.shape(), 0)
# Remove the seq axis.
# Now unique_token_seqs has only two axes [path][token_id]
unique_token_seqs = k2.ragged.remove_axis(unique_token_seqs, 0)
token_fsas = k2.linear_fsa(unique_token_seqs)
return Nbest(fsa=token_fsas, shape=seq_to_path_shape)
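# A hedged sketch of the intended multi-round flow (variable names are
# assumptions; `lats` is a first-pass decoding lattice with `lm_scores`):
#
#   rescored_lats = whole_lattice_rescoring(lats, G_with_epsilon_loops)
#   nbest = generate_nbest_list(rescored_lats, num_paths=100)
#   nbest = nbest.top_k(10)                 # regularize the shape
#   nbest = nbest.intersect(rescored_lats)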

icefall/utils.py

@@ -5,6 +5,7 @@ import subprocess
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from icefall.nbest import Nbest
from pathlib import Path
from typing import Dict, Iterable, List, TextIO, Tuple, Union
@@ -381,3 +382,75 @@ def write_error_stats(
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
return float(tot_err_rate)
def get_best_matching_stats(keys: Nbest, queries: Nbest,
max_order: int) -> torch.Tensor:
'''Get best matching stats on query positions.
Args:
keys:
The nbest after doing second pass rescoring.
queries:
Another nbest before doing second pass rescoring.
max_order:
The maximum n-gram order ever returned by `k2.get_best_matching_stats`.
Returns:
A tensor of shape [queries.fsa.num_elements, 5]; each row contains
the stats (init_score, mean, var, counts_out, ngram_order)
of the token at the corresponding position in queries.
'''
assert keys.shape.dim0() == queries.shape.dim0(), \
f'Number of utterances in keys and queries should be equal: \
{keys.shape.dim0()} vs {queries.shape.dim0()}'
# keys_tokens_shape [utt][path][token]
keys_tokens_shape = k2.ragged.compose_ragged_shapes(keys.shape,
k2.ragged.remove_axis(keys.fsa.arcs.shape(), 1))
# queries_tokens_shape [utt][path][token]
queries_tokens_shape = k2.ragged.compose_ragged_shapes(queries.shape,
k2.ragged.remove_axis(queries.fsa.arcs.shape(), 1))
keys_tokens = k2.RaggedInt(keys_tokens_shape, keys.fsa.labels.clone())
queries_tokens = k2.RaggedInt(queries_tokens_shape,
queries.fsa.labels.clone())
# tokens shape [utt][path][token]
tokens = k2.ragged.cat([keys_tokens, queries_tokens], axis=1)
keys_token_num = keys.fsa.labels.size()[0]
queries_tokens_num = queries.fsa.labels.size()[0]
# counts on key positions are ones
keys_counts = k2.RaggedInt(keys_tokens_shape,
torch.ones(keys_token_num,
dtype=torch.int32,
device=keys.fsa.device))
# counts on query positions are zeros
queries_counts = k2.RaggedInt(queries_tokens_shape,
torch.zeros(queries_tokens_num,
dtype=torch.int32,
device=queries.fsa.device))
counts = k2.ragged.cat([keys_counts, queries_counts], axis=1).values()
# scores on key positions are the scores inherited from the nbest path
keys_scores = k2.RaggedFloat(keys_tokens_shape, keys.fsa.scores.clone())
# scores on query positions MUST be zeros
queries_scores = k2.RaggedFloat(queries_tokens_shape,
torch.zeros(queries_tokens_num,
dtype=torch.float32,
device=queries.fsa.device))
scores = k2.ragged.cat([keys_scores, queries_scores], axis=1).values()
# we didn't remove -1 labels before
min_token = -1
eos = -1
max_token = torch.max(torch.max(keys.fsa.labels),
torch.max(queries.fsa.labels))
mean, var, counts_out, ngram = k2.get_best_matching_stats(tokens, scores,
counts, eos, min_token, max_token, max_order)
queries_init_scores = queries.fsa.scores.clone()
# only return the stats on query positions
masking = counts == 0
# shape [queries_tokens_num, 5]
return torch.transpose(torch.stack((queries_init_scores, mean[masking],
var[masking], counts_out[masking],
ngram[masking])), 0, 1)
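# A hedged usage sketch (names are assumptions): `keys` is the Nbest after
# second-pass rescoring and `queries` the one before; both must cover the
# same utterances.
#
#   stats = get_best_matching_stats(keys, queries, max_order=3)
#   init_score, mean, var, counts_out, ngram_order = stats.unbind(dim=1)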

test/test_nbest.py (new file)
@@ -0,0 +1,110 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# To run this single test, use
#
# ctest --verbose -R nbest_test_py
import unittest
import k2
import torch
from icefall.nbest import Nbest
class TestNbest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.devices = [torch.device('cpu')]
if torch.cuda.is_available() and k2.with_cuda:
cls.devices.append(torch.device('cuda', 0))
if torch.cuda.device_count() > 1:
torch.cuda.set_device(1)
cls.devices.append(torch.device('cuda', 1))
def test_nbest_constructor(self):
fsa = k2.Fsa.from_str('''
0 1 -1 0.1
1
''')
fsa_vec = k2.create_fsa_vec([fsa, fsa, fsa])
shape = k2.RaggedShape('[[x x] [x]]')
Nbest(fsa_vec, shape)
def test_top_k(self):
fsa0 = k2.Fsa.from_str('''
0 1 -1 0
1
''')
fsas = [fsa0.clone() for i in range(10)]
fsa_vec = k2.create_fsa_vec(fsas)
fsa_vec.scores = torch.tensor([3, 0, 1, 5, 4, 2, 8, 1, 9, 6],
dtype=torch.float)
# 0 1 2 3 4 5 6 7 8 9
# [ [3 0] [1 5 4] [2 8 1 9 6] ]
shape = k2.RaggedShape('[ [x x] [x x x] [x x x x x] ]')
nbest = Nbest(fsa_vec, shape)
# top_k: k is 1
nbest1 = nbest.top_k(1)
expected_fsa = k2.create_fsa_vec([fsa_vec[0], fsa_vec[3], fsa_vec[8]])
assert str(nbest1.fsa) == str(expected_fsa)
expected_shape = k2.RaggedShape('[ [x] [x] [x] ]')
assert nbest1.shape == expected_shape
# top_k: k is 2
nbest2 = nbest.top_k(2)
expected_fsa = k2.create_fsa_vec([
fsa_vec[0], fsa_vec[1], fsa_vec[3], fsa_vec[4], fsa_vec[8],
fsa_vec[6]
])
assert str(nbest2.fsa) == str(expected_fsa)
expected_shape = k2.RaggedShape('[ [x x] [x x] [x x] ]')
assert nbest2.shape == expected_shape
# top_k: k is 3
nbest3 = nbest.top_k(3)
expected_fsa = k2.create_fsa_vec([
fsa_vec[0], fsa_vec[1], fsa_vec[1], fsa_vec[3], fsa_vec[4],
fsa_vec[2], fsa_vec[8], fsa_vec[6], fsa_vec[9]
])
assert str(nbest3.fsa) == str(expected_fsa)
expected_shape = k2.RaggedShape('[ [x x x] [x x x] [x x x] ]')
assert nbest3.shape == expected_shape
# top_k: k is 4
nbest4 = nbest.top_k(4)
expected_fsa = k2.create_fsa_vec([
fsa_vec[0], fsa_vec[1], fsa_vec[1], fsa_vec[1], fsa_vec[3],
fsa_vec[4], fsa_vec[2], fsa_vec[2], fsa_vec[8], fsa_vec[6],
fsa_vec[9], fsa_vec[5]
])
assert str(nbest4.fsa) == str(expected_fsa)
expected_shape = k2.RaggedShape('[ [x x x x] [x x x x] [x x x x] ]')
assert nbest4.shape == expected_shape
if __name__ == '__main__':
unittest.main()