mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
111 lines
3.5 KiB
Python
111 lines
3.5 KiB
Python
#!/usr/bin/env python3
|
|
#
|
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
|
#
|
|
# See ../../../LICENSE for clarification regarding multiple authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
# To run this single test, use
|
|
#
|
|
# ctest --verbose -R nbest_test_py
|
|
|
|
import unittest
|
|
|
|
import k2
|
|
import torch
|
|
|
|
from icefall.nbest import Nbest
|
|
|
|
|
|
class TestNbest(unittest.TestCase):
    """Unit tests for :class:`icefall.nbest.Nbest`."""

    @classmethod
    def setUpClass(cls):
        # CPU is always tested; CUDA devices are added only when torch sees a
        # GPU and k2 was built with CUDA.
        # NOTE(review): the test methods below do not iterate over
        # ``cls.devices`` yet, so only default-device tensors are exercised.
        cls.devices = [torch.device('cpu')]
        if torch.cuda.is_available() and k2.with_cuda:
            cls.devices.append(torch.device('cuda', 0))
            if torch.cuda.device_count() > 1:
                # Side effect: makes GPU 1 the current CUDA device for the
                # rest of the process.
                torch.cuda.set_device(1)
                cls.devices.append(torch.device('cuda', 1))

    def test_nbest_constructor(self):
        """An Nbest can be built from an FsaVec and a matching 2-axis shape."""
        # Single arc 0 -> 1 with label -1 (final arc) and score 0.1.
        fsa = k2.Fsa.from_str('''
            0 1 -1 0.1
            1
        ''')

        fsa_vec = k2.create_fsa_vec([fsa, fsa, fsa])
        # Two utterances with 2 and 1 paths respectively: 3 paths in total,
        # one per FSA in fsa_vec.
        shape = k2.RaggedShape('[[x x] [x]]')
        nbest = Nbest(fsa_vec, shape)

        self.assertEqual(nbest.shape, shape)
        self.assertEqual(str(nbest.fsa), str(fsa_vec))

    def test_top_k(self):
        """top_k(k) keeps the k best paths per utterance, padding if needed."""
        fsa0 = k2.Fsa.from_str('''
            0 1 -1 0
            1
        ''')
        fsa_vec = k2.create_fsa_vec([fsa0.clone() for _ in range(10)])
        # Each FSA has exactly one arc, so assigning one score per arc gives
        # each path the listed score.
        fsa_vec.scores = torch.tensor([3, 0, 1, 5, 4, 2, 8, 1, 9, 6],
                                      dtype=torch.float)

        # path index:  0 1   2 3 4   5 6 7 8 9
        # score:     [ [3 0] [1 5 4] [2 8 1 9 6] ]
        shape = k2.RaggedShape('[ [x x] [x x x] [x x x x x] ]')
        nbest = Nbest(fsa_vec, shape)

        # For each k:
        #   - the indexes into fsa_vec of the expected paths, grouped by
        #     utterance in descending score order;
        #   - the expected regular shape.
        # An utterance with fewer than k paths is padded by repeating its
        # lowest-scoring path.
        expected = {
            1: ([[0], [3], [8]], '[ [x] [x] [x] ]'),
            2: ([[0, 1], [3, 4], [8, 6]], '[ [x x] [x x] [x x] ]'),
            3: ([[0, 1, 1], [3, 4, 2], [8, 6, 9]],
                '[ [x x x] [x x x] [x x x] ]'),
            4: ([[0, 1, 1, 1], [3, 4, 2, 2], [8, 6, 9, 5]],
                '[ [x x x x] [x x x x] [x x x x] ]'),
        }

        for k, (indexes, shape_str) in expected.items():
            with self.subTest(k=k):
                nbest_k = nbest.top_k(k)

                expected_fsa = k2.create_fsa_vec(
                    [fsa_vec[i] for utt in indexes for i in utt])
                # FsaVecs are compared via their string form.
                self.assertEqual(str(nbest_k.fsa), str(expected_fsa))
                self.assertEqual(nbest_k.shape, k2.RaggedShape(shape_str))
|
|
if __name__ == "__main__":
    # Run the test cases when this file is executed as a script.
    unittest.main()
|