mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Copy the files related to multi-round n-best rescoring from k2 & snowfall
This commit is contained in:
parent cf8d76293d
commit cabe8b625b
264 icefall/nbest.py Normal file
@@ -0,0 +1,264 @@
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# This file implements the ideas proposed by Daniel Povey.
#
# See https://github.com/k2-fsa/snowfall/issues/232 for more details
#

import logging
from typing import List

import torch
import _k2
import k2

# Note: We use `utterance` and `sequence` interchangeably in the comments.


class Nbest(object):
    '''
    An Nbest object contains two fields:

        (1) fsa, whose type is k2.Fsa
        (2) shape, whose type is k2.RaggedShape (an alias of _k2.RaggedShape)

    The field `fsa` is an FsaVec containing a vector of **linear** FSAs.

    The field `shape` has two axes [utt][path]. `shape.dim0()` is the
    number of utterances, which is also the number of rows in the
    supervision_segments. `shape.tot_size(1)` is the number of paths,
    which is also the number of FSAs in `fsa`.
    '''

    def __init__(self, fsa: k2.Fsa, shape: _k2.RaggedShape) -> None:
        assert len(fsa.shape) == 3, f'fsa.shape: {fsa.shape}'
        assert shape.num_axes() == 2, f'num_axes: {shape.num_axes()}'

        assert fsa.shape[0] == shape.tot_size(1), \
            f'{fsa.shape[0]} vs {shape.tot_size(1)}'

        self.fsa = fsa
        self.shape = shape

    def __str__(self):
        s = 'Nbest('
        s += f'num_seqs:{self.shape.dim0()}, '
        s += f'num_fsas:{self.fsa.shape[0]})'
        return s
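
    # Illustrative note (not part of the original commit): with
    # shape = k2.RaggedShape('[[x x] [x]]') there are two utterances,
    # the first with two paths and the second with one, so `fsa` must be
    # an FsaVec holding exactly 3 linear FSAs; here shape.dim0() == 2 and
    # shape.tot_size(1) == 3. The constructor test in test/test_nbest.py
    # below builds exactly this configuration.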

    def intersect(self, lats: k2.Fsa) -> 'Nbest':
        '''Intersect this Nbest object with a lattice and get the 1-best
        path from the resulting FsaVec.

        Caution:
          We assume FSAs in `self.fsa` don't have epsilon self-loops.
          We also assume `self.fsa.labels` and `lats.labels` are token IDs.

        Args:
          lats:
            An FsaVec. It can be the return value of
            :func:`whole_lattice_rescoring`.
        Returns:
          Return a new Nbest. This new Nbest shares the same shape with
          `self`, while its `fsa` is the 1-best path from intersecting
          `self.fsa` and `lats`.
        '''
        assert self.fsa.device == lats.device, \
            f'{self.fsa.device} vs {lats.device}'
        assert len(lats.shape) == 3, f'{lats.shape}'
        assert lats.arcs.dim0() == self.shape.dim0(), \
            f'{lats.arcs.dim0()} vs {self.shape.dim0()}'

        lats = k2.arc_sort(lats)  # no-op if lats is already arc sorted

        fsas_with_epsilon_loops = k2.add_epsilon_self_loops(self.fsa)

        path_to_seq_map = self.shape.row_ids(1)

        ans_lats = k2.intersect_device(a_fsas=lats,
                                       b_fsas=fsas_with_epsilon_loops,
                                       b_to_a_map=path_to_seq_map,
                                       sorted_match_a=True)

        one_best = k2.shortest_path(ans_lats, use_double_scores=True)

        one_best = k2.remove_epsilon(one_best)

        return Nbest(fsa=one_best, shape=self.shape)
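
    # Why the steps above look the way they do (an illustrative recap, not
    # in the original commit): k2.intersect_device matches labels
    # literally, treating epsilon as an ordinary symbol, so epsilon
    # self-loops are added to the linear FSAs to let them pass over
    # epsilon arcs in `lats`; `path_to_seq_map` (shape.row_ids(1)) tells
    # the intersection which utterance's lattice each path belongs to; and
    # the epsilons introduced by the self-loops are removed again at the
    # end, so the returned FSAs stay epsilon-free.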

    def total_scores(self) -> _k2.RaggedFloat:
        '''Get total scores of the FSAs in this Nbest.

        Note:
          Since the FSAs in this Nbest are just linear FSAs, the log
          semiring and the tropical semiring produce the same total scores.

        Returns:
          Return a ragged tensor with two axes [utt][path_scores].
        '''
        scores = self.fsa.get_tot_scores(use_double_scores=True,
                                         log_semiring=False)
        # We use single precision here since we only wrap k2.RaggedFloat.
        # If k2.RaggedDouble is wrapped, we can use double precision here.
        return _k2.RaggedFloat(self.shape, scores.float())
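
    # Why the semiring choice does not matter here (added reasoning): a
    # linear FSA has exactly one path p, so the log-semiring total
    # log(sum_p exp(score(p))) and the tropical total max_p score(p) both
    # collapse to the score of that single path.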

    def top_k(self, k: int) -> 'Nbest':
        '''Get a subset of paths in the Nbest. The resulting Nbest is
        regular in that each sequence (i.e., utterance) has the same
        number of paths (k).

        We select the top-k paths according to the total_scores of each
        path. If an utterance has fewer than k paths, then its last path,
        after sorting by tot_scores in descending order, is repeated so
        that each utterance has exactly k paths.

        Args:
          k:
            Number of paths in each utterance.
        Returns:
          Return a new Nbest with a regular shape.
        '''
        ragged_scores = self.total_scores()

        # indexes contains idx01's for self.shape
        # ragged_scores.values()[indexes] is sorted
        indexes = k2.ragged.sort_sublist(ragged_scores,
                                         descending=True,
                                         need_new2old_indexes=True)

        ragged_indexes = k2.RaggedInt(self.shape, indexes)

        padded_indexes = k2.ragged.pad(ragged_indexes,
                                       mode='replicate',
                                       value=-1)
        assert torch.ge(padded_indexes, 0).all(), \
            'Some utterances contain empty ' \
            f'n-best: {self.shape.row_splits(1)}'

        # Select the idx01's of the top-k paths of each utterance
        top_k_indexes = padded_indexes[:, :k].flatten().contiguous()

        top_k_fsas = k2.index_fsa(self.fsa, top_k_indexes)

        top_k_shape = k2.ragged.regular_ragged_shape(dim0=self.shape.dim0(),
                                                     dim1=k)
        return Nbest(top_k_fsas, top_k_shape)
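
    # A worked example of the replicate-padding trick (taken from the data
    # in test_top_k below): an utterance whose paths occupy idx01's
    # [2, 3, 4] with total scores [1, 5, 4] sorts to indexes [3, 4, 2].
    # k2.ragged.pad(..., mode='replicate') widens each row to the longest
    # utterance by repeating its last entry, e.g. [3, 4, 2, 2, 2] for a
    # padded width of 5, and padded_indexes[:, :4] then yields
    # [3, 4, 2, 2]: the worst path is duplicated until k paths exist.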


def whole_lattice_rescoring(lats: k2.Fsa,
                            G_with_epsilon_loops: k2.Fsa) -> k2.Fsa:
    '''Rescore the 1st pass lattice with an LM.

    In general, the G in the HLG used to obtain `lats` is a 3-gram LM.
    This function replaces the 3-gram LM in `lats` with a 4-gram LM.

    Args:
      lats:
        The decoding lattice from the 1st pass. We assume it is the result
        of intersecting HLG with the network output.
      G_with_epsilon_loops:
        An LM. It is usually a 4-gram LM with epsilon self-loops.
        It should be arc sorted.
    Returns:
      Return a new lattice rescored with the given G.
    '''
    assert len(lats.shape) == 3, f'{lats.shape}'
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None), \
        f'{G_with_epsilon_loops.shape}'

    device = lats.device
    lats.scores = lats.scores - lats.lm_scores
    # Now lats contains only acoustic scores

    # We will use lm_scores from the given G, so remove lats.lm_scores here
    del lats.lm_scores
    assert hasattr(lats, 'lm_scores') is False

    # inverted_lats has word IDs as labels.
    # Its aux_labels are token IDs, which is a ragged tensor k2.RaggedInt
    # if lats.aux_labels is a ragged tensor.
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]

    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)

    while True:
        try:
            rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                                 inverted_lats,
                                                 b_to_a_map,
                                                 sorted_match_a=True)
            break
        except RuntimeError as e:
            logging.info(f'Caught exception:\n{e}\n')
            # Usually, this is an OOM exception. We reduce the size of
            # the lattice and redo k2.intersect_device()

            # NOTE(fangjun): The choice of the threshold 1e-5 here is
            # arbitrary, just to avoid OOM. We may need to fine-tune it.
            logging.info(f'num_arcs before: {inverted_lats.num_arcs}')
            inverted_lats = k2.prune_on_arc_post(inverted_lats, 1e-5, True)
            logging.info(f'num_arcs after: {inverted_lats.num_arcs}')

    rescoring_lats = k2.top_sort(k2.connect(rescoring_lats))

    # inv_rescoring_lats has token IDs as labels
    # and word IDs as aux_labels.
    inv_rescoring_lats = k2.invert(rescoring_lats)
    return inv_rescoring_lats
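
# How the LM swap works (added reasoning, not in the original commit):
# after `lats.scores = lats.scores - lats.lm_scores` every arc carries
# only its acoustic score, and intersecting with G adds G's score to each
# matched arc, so in the returned lattice
#
#     score(arc) = am_score(arc) + lm_score_G(arc)
#
# i.e., the 3-gram scores baked into HLG are replaced by the scores of
# the given (typically 4-gram) G.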


def generate_nbest_list(lats: k2.Fsa, num_paths: int) -> Nbest:
    '''Generate an n-best list from a lattice.

    Args:
      lats:
        The decoding lattice from the first pass after LM rescoring.
        lats is an FsaVec. It can be the return value of
        :func:`whole_lattice_rescoring`.
      num_paths:
        Size of n for the n-best list. CAUTION: After removing paths
        that represent the same token sequences, the number of paths
        in different sequences may not be equal.
    Return:
      Return an Nbest object. Note the returned FSAs don't have epsilon
      self-loops.
    '''
    assert len(lats.shape) == 3

    # CAUTION: We use `phones` instead of `tokens` here because
    # :func:`compile_HLG` uses `phones`
    #
    # Note: compile_HLG is from k2-fsa/snowfall
    assert hasattr(lats, 'phones')

    assert not hasattr(lats, 'tokens')
    lats.tokens = lats.phones
    # we use tokens instead of phones in the following code

    # First, extract `num_paths` paths for each sequence.
    # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)

    # token_seqs is a k2.RaggedInt sharing the same shape as `paths`,
    # but it contains token IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.
    # Its axes are [seq][path][token_id]
    token_seqs = k2.index(lats.tokens, paths)

    # Remove epsilons (0s) and -1 from token_seqs
    token_seqs = k2.ragged.remove_values_leq(token_seqs, 0)

    # unique_token_seqs is still a k2.RaggedInt with axes
    # [seq][path][token_id], but the number of paths in each sequence
    # may be different.
    unique_token_seqs, _, _ = k2.ragged.unique_sequences(
        token_seqs, need_num_repeats=False, need_new2old_indexes=False)

    seq_to_path_shape = k2.ragged.get_layer(unique_token_seqs.shape(), 0)

    # Remove the seq axis.
    # Now unique_token_seqs has only two axes [path][token_id]
    unique_token_seqs = k2.ragged.remove_axis(unique_token_seqs, 0)

    token_fsas = k2.linear_fsa(unique_token_seqs)

    return Nbest(fsa=token_fsas, shape=seq_to_path_shape)
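

# A minimal end-to-end sketch of how these pieces compose for multi-round
# n-best rescoring (illustrative only; `lats` and `G_4gram` are assumed
# inputs, not defined in this commit):
#
#     rescored_lats = whole_lattice_rescoring(lats, G_4gram)
#     nbest = generate_nbest_list(rescored_lats, num_paths=100)
#     nbest = nbest.top_k(10)          # make the shape regular
#     nbest = nbest.intersect(rescored_lats)
#     scores = nbest.total_scores()    # ragged, axes [utt][path_scores]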

icefall/utils.py
@@ -5,6 +5,7 @@ import subprocess
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from nbest import Nbest
from pathlib import Path
from typing import Dict, Iterable, List, TextIO, Tuple, Union

@@ -381,3 +382,75 @@ def write_error_stats(
        print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
    return float(tot_err_rate)


def get_best_matching_stats(keys: Nbest, queries: Nbest,
                            max_order: int) -> torch.Tensor:
    '''Get best matching stats on query positions.

    Args:
      keys:
        The nbest after doing second pass rescoring.
      queries:
        Another nbest before doing second pass rescoring.
      max_order:
        The maximum n-gram order ever returned by
        `k2.get_best_matching_stats`.

    Returns:
      A tensor with shape [queries.fsa.num_elements, 5]; each row
      contains the stats (init_score, mean, var, counts_out, ngram_order)
      of the token in the corresponding position in queries.
    '''
    assert keys.shape.dim0() == queries.shape.dim0(), \
        'The number of utterances in keys and queries should be equal: ' \
        f'{keys.shape.dim0()} vs {queries.shape.dim0()}'

    # keys_tokens_shape has axes [utt][path][token]
    keys_tokens_shape = k2.ragged.compose_ragged_shapes(
        keys.shape, k2.ragged.remove_axis(keys.fsa.arcs.shape(), 1))
    # queries_tokens_shape has axes [utt][path][token]
    queries_tokens_shape = k2.ragged.compose_ragged_shapes(
        queries.shape, k2.ragged.remove_axis(queries.fsa.arcs.shape(), 1))

    keys_tokens = k2.RaggedInt(keys_tokens_shape, keys.fsa.labels.clone())
    queries_tokens = k2.RaggedInt(queries_tokens_shape,
                                  queries.fsa.labels.clone())
    # tokens has axes [utt][path][token]
    tokens = k2.ragged.cat([keys_tokens, queries_tokens], axis=1)

    keys_token_num = keys.fsa.labels.size()[0]
    queries_tokens_num = queries.fsa.labels.size()[0]
    # counts on key positions are ones
    keys_counts = k2.RaggedInt(keys_tokens_shape,
                               torch.ones(keys_token_num,
                                          dtype=torch.int32))
    # counts on query positions are zeros
    queries_counts = k2.RaggedInt(queries_tokens_shape,
                                  torch.zeros(queries_tokens_num,
                                              dtype=torch.int32))
    counts = k2.ragged.cat([keys_counts, queries_counts], axis=1).values()

    # scores on key positions are the scores inherited from the nbest path
    keys_scores = k2.RaggedFloat(keys_tokens_shape, keys.fsa.scores.clone())
    # scores on query positions MUST be zeros
    queries_scores = k2.RaggedFloat(queries_tokens_shape,
                                    torch.zeros(queries_tokens_num,
                                                dtype=torch.float32))
    scores = k2.ragged.cat([keys_scores, queries_scores], axis=1).values()

    # we didn't remove -1 labels before
    min_token = -1
    eos = -1
    max_token = torch.max(torch.max(keys.fsa.labels),
                          torch.max(queries.fsa.labels))
    mean, var, counts_out, ngram = k2.get_best_matching_stats(
        tokens, scores, counts, eos, min_token, max_token, max_order)

    queries_init_scores = queries.fsa.scores.clone()
    # only return the stats on query positions
    masking = counts == 0
    # shape [queries_tokens_num, 5]
    return torch.transpose(torch.stack((queries_init_scores, mean[masking],
                                        var[masking], counts_out[masking],
                                        ngram[masking])), 0, 1)
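
# How the key/query separation works (added recap, not in the original
# commit): positions given count 1 (the keys) contribute statistics,
# while positions given count 0 (the queries) only receive them, so
# `masking = counts == 0` picks out exactly the query positions. Each
# row of the returned [queries_tokens_num, 5] tensor therefore lines up
# with one token of `queries` and holds
# (init_score, mean, var, counts_out, ngram_order).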

110 test/test_nbest.py Normal file
@@ -0,0 +1,110 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# To run this single test, use
#
#   ctest --verbose -R nbest_test_py

import unittest

import k2
import torch

from icefall.nbest import Nbest


class TestNbest(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.devices = [torch.device('cpu')]
        if torch.cuda.is_available() and k2.with_cuda:
            cls.devices.append(torch.device('cuda', 0))
            if torch.cuda.device_count() > 1:
                torch.cuda.set_device(1)
                cls.devices.append(torch.device('cuda', 1))

    def test_nbest_constructor(self):
        fsa = k2.Fsa.from_str('''
            0 1 -1 0.1
            1
        ''')

        fsa_vec = k2.create_fsa_vec([fsa, fsa, fsa])
        shape = k2.RaggedShape('[[x x] [x]]')
        Nbest(fsa_vec, shape)

    def test_top_k(self):
        fsa0 = k2.Fsa.from_str('''
            0 1 -1 0
            1
        ''')
        fsas = [fsa0.clone() for i in range(10)]
        fsa_vec = k2.create_fsa_vec(fsas)
        fsa_vec.scores = torch.tensor([3, 0, 1, 5, 4, 2, 8, 1, 9, 6],
                                      dtype=torch.float)
        # FSA indexes:  0 1    2 3 4    5 6 7 8 9
        # scores:     [ [3 0]  [1 5 4]  [2 8 1 9 6] ]
        shape = k2.RaggedShape('[ [x x] [x x x] [x x x x x] ]')
        nbest = Nbest(fsa_vec, shape)

        # top_k: k is 1
        nbest1 = nbest.top_k(1)
        expected_fsa = k2.create_fsa_vec([fsa_vec[0], fsa_vec[3], fsa_vec[8]])
        assert str(nbest1.fsa) == str(expected_fsa)

        expected_shape = k2.RaggedShape('[ [x] [x] [x] ]')
        assert nbest1.shape == expected_shape

        # top_k: k is 2
        nbest2 = nbest.top_k(2)
        expected_fsa = k2.create_fsa_vec([
            fsa_vec[0], fsa_vec[1], fsa_vec[3], fsa_vec[4], fsa_vec[8],
            fsa_vec[6]
        ])
        assert str(nbest2.fsa) == str(expected_fsa)

        expected_shape = k2.RaggedShape('[ [x x] [x x] [x x] ]')
        assert nbest2.shape == expected_shape

        # top_k: k is 3
        nbest3 = nbest.top_k(3)
        expected_fsa = k2.create_fsa_vec([
            fsa_vec[0], fsa_vec[1], fsa_vec[1], fsa_vec[3], fsa_vec[4],
            fsa_vec[2], fsa_vec[8], fsa_vec[6], fsa_vec[9]
        ])
        assert str(nbest3.fsa) == str(expected_fsa)

        expected_shape = k2.RaggedShape('[ [x x x] [x x x] [x x x] ]')
        assert nbest3.shape == expected_shape

        # top_k: k is 4
        nbest4 = nbest.top_k(4)
        expected_fsa = k2.create_fsa_vec([
            fsa_vec[0], fsa_vec[1], fsa_vec[1], fsa_vec[1], fsa_vec[3],
            fsa_vec[4], fsa_vec[2], fsa_vec[2], fsa_vec[8], fsa_vec[6],
            fsa_vec[9], fsa_vec[5]
        ])
        assert str(nbest4.fsa) == str(expected_fsa)

        expected_shape = k2.RaggedShape('[ [x x x x] [x x x x] [x x x x] ]')
        assert nbest4.shape == expected_shape


if __name__ == '__main__':
    unittest.main()