icefall/test/test_nbest.py

111 lines
3.5 KiB
Python

#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# To run this single test, use
#
# ctest --verbose -R nbest_test_py
import unittest
import k2
import torch
from icefall.nbest import Nbest
class TestNbest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.devices = [torch.device('cpu')]
if torch.cuda.is_available() and k2.with_cuda:
cls.devices.append(torch.device('cuda', 0))
if torch.cuda.device_count() > 1:
torch.cuda.set_device(1)
cls.devices.append(torch.device('cuda', 1))
def test_nbest_constructor(self):
fsa = k2.Fsa.from_str('''
0 1 -1 0.1
1
''')
fsa_vec = k2.create_fsa_vec([fsa, fsa, fsa])
shape = k2.RaggedShape('[[x x] [x]]')
Nbest(fsa_vec, shape)
def test_top_k(self):
fsa0 = k2.Fsa.from_str('''
0 1 -1 0
1
''')
fsas = [fsa0.clone() for i in range(10)]
fsa_vec = k2.create_fsa_vec(fsas)
fsa_vec.scores = torch.tensor([3, 0, 1, 5, 4, 2, 8, 1, 9, 6],
dtype=torch.float)
# 0 1 2 3 4 5 6 7 8 9
# [ [3 0] [1 5 4] [2 8 1 9 6]
shape = k2.RaggedShape('[ [x x] [x x x] [x x x x x] ]')
nbest = Nbest(fsa_vec, shape)
# top_k: k is 1
nbest1 = nbest.top_k(1)
expected_fsa = k2.create_fsa_vec([fsa_vec[0], fsa_vec[3], fsa_vec[8]])
assert str(nbest1.fsa) == str(expected_fsa)
expected_shape = k2.RaggedShape('[ [x] [x] [x] ]')
assert nbest1.shape == expected_shape
# top_k: k is 2
nbest2 = nbest.top_k(2)
expected_fsa = k2.create_fsa_vec([
fsa_vec[0], fsa_vec[1], fsa_vec[3], fsa_vec[4], fsa_vec[8],
fsa_vec[6]
])
assert str(nbest2.fsa) == str(expected_fsa)
expected_shape = k2.RaggedShape('[ [x x] [x x] [x x] ]')
assert nbest2.shape == expected_shape
# top_k: k is 3
nbest3 = nbest.top_k(3)
expected_fsa = k2.create_fsa_vec([
fsa_vec[0], fsa_vec[1], fsa_vec[1], fsa_vec[3], fsa_vec[4],
fsa_vec[2], fsa_vec[8], fsa_vec[6], fsa_vec[9]
])
assert str(nbest3.fsa) == str(expected_fsa)
expected_shape = k2.RaggedShape('[ [x x x] [x x x] [x x x] ]')
assert nbest3.shape == expected_shape
# top_k: k is 4
nbest4 = nbest.top_k(4)
expected_fsa = k2.create_fsa_vec([
fsa_vec[0], fsa_vec[1], fsa_vec[1], fsa_vec[1], fsa_vec[3],
fsa_vec[4], fsa_vec[2], fsa_vec[2], fsa_vec[8], fsa_vec[6],
fsa_vec[9], fsa_vec[5]
])
assert str(nbest4.fsa) == str(expected_fsa)
expected_shape = k2.RaggedShape('[ [x x x x] [x x x x] [x x x x] ]')
assert nbest4.shape == expected_shape
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()