#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the helpers exported by ``icefall.lm.rescore``."""

import k2
import torch

from icefall.lm.rescore import (
    add_bos,
    add_eos,
    make_repeat_map,
    make_hyp_to_ref_map,
    make_repeat,
    compute_alignments,
)


def test_add_bos():
    """add_bos() prepends ``bos_id`` to every sublist of a ragged tensor."""
    bos_id = 100
    ragged = k2.RaggedTensor([[1, 2], [3], [0]])
    bos_ragged = add_bos(ragged, bos_id)
    expected = k2.RaggedTensor([[bos_id, 1, 2], [bos_id, 3], [bos_id, 0]])
    assert str(bos_ragged) == str(expected)


def test_add_eos():
    """add_eos() appends ``eos_id`` to every sublist, including empty ones."""
    eos_id = 30
    ragged = k2.RaggedTensor([[1, 2], [3], [], [5, 8, 9]])
    ragged_eos = add_eos(ragged, eos_id)
    expected = k2.RaggedTensor(
        [[1, 2, eos_id], [3, eos_id], [eos_id], [5, 8, 9, eos_id]]
    )
    # The original test built `expected` but never compared it, so it could
    # not fail; compare the same way test_add_bos does.
    assert str(ragged_eos) == str(expected)


def test_pad():
    """BOS + EOS wrapping followed by constant padding gives a dense tensor.

    Rows shorter than the longest one are right-padded with ``blank_id``.
    """
    bos_id = 10
    eos_id = 100
    ragged = k2.RaggedTensor([[1, 2, 3], [5], [9, 8]])
    bos_ragged = add_bos(ragged, bos_id)
    bos_ragged_eos = add_eos(bos_ragged, eos_id)
    blank_id = -1
    padded = bos_ragged_eos.pad(mode="constant", padding_value=blank_id)
    expected = torch.tensor(
        [
            [bos_id, 1, 2, 3, eos_id],
            [bos_id, 5, eos_id, blank_id, blank_id],
            [bos_id, 9, 8, eos_id, blank_id],
        ]
    ).to(padded)
    # torch.equal also checks the shape, unlike torch.all(torch.eq(...)),
    # which would broadcast a wrongly-shaped result.
    assert torch.equal(padded, expected)


def test_make_hyp_to_ref_map():
    """Check make_hyp_to_ref_map() on 2 utterances with 3 and 4 hypotheses.

    For an utterance whose n hypotheses occupy flat indices s .. s+n-1,
    each index is expected to appear n times in a row, so the result has
    3*3 + 4*4 = 25 entries.
    """
    a = k2.RaggedTensor([[[1, 2], [], [3]], [[1, 3], [2], [4], [5]]])
    # row_splits(1) marks where each utterance's hypotheses start/end.
    row_splits = a.shape.row_splits(1)
    repeat_map = make_hyp_to_ref_map(row_splits)
    # fmt: off
    expected = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2,
                             3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5,
                             6, 6, 6, 6]).to(repeat_map)
    # fmt: on
    assert torch.equal(repeat_map, expected)


def test_make_repeat_map():
    """Check make_repeat_map() on 2 utterances with 3 and 4 hypotheses.

    For an utterance whose n hypotheses occupy flat indices s .. s+n-1,
    the whole run s .. s+n-1 is expected to be tiled n times. Read
    position-wise against the expected output of make_hyp_to_ref_map(),
    this enumerates every (i, j) hypothesis pair within an utterance.
    """
    a = k2.RaggedTensor([[[1, 2], [], [3]], [[1, 3], [2], [4], [5]]])
    row_splits = a.shape.row_splits(1)
    repeat_map = make_repeat_map(row_splits)
    # fmt: off
    expected = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2,
                             3, 4, 5, 6, 3, 4, 5, 6, 3, 4, 5, 6,
                             3, 4, 5, 6]).to(repeat_map)
    # fmt: on
    assert torch.equal(repeat_map, expected)


def test_make_repeat():
    """make_repeat() repeats each utterance's hypothesis list n times.

    Here n is that utterance's number of hypotheses: 2 hyps are repeated
    twice and 4 hyps four times. Empty hypotheses are preserved.
    """
    # fmt: off
    a = k2.RaggedTensor([
        [[1, 3, 5], [2, 6]],
        [[1, 2, 3, 4], [2], [], [9, 10, 11]],
    ])
    b = make_repeat(a)
    expected = k2.RaggedTensor([
        [[1, 3, 5], [2, 6], [1, 3, 5], [2, 6]],
        [[1, 2, 3, 4], [2], [], [9, 10, 11],
         [1, 2, 3, 4], [2], [], [9, 10, 11],
         [1, 2, 3, 4], [2], [], [9, 10, 11],
         [1, 2, 3, 4], [2], [], [9, 10, 11]],
    ])
    # fmt: on
    assert str(b) == str(expected)


def test_compute_alignments():
    """Smoke test for compute_alignments() on 5 token sequences.

    The 5 sequences are grouped by ``shape`` into two utterances with 3 and
    2 hypotheses.
    """
    # fmt: off
    tokens = k2.RaggedTensor([
        # utt 0
        [1, 3], [8], [2],
        # utt 1
        [1, 5], [9],
    ])
    shape = k2.RaggedShape('[[x x x] [x x]]')
    # fmt: on
    alignment = compute_alignments(tokens, shape)
    # NOTE(review): this test only prints the result and asserts nothing;
    # add expected values for labels/hyp_labels/ref_labels once confirmed.
    print("masked_src:", alignment.labels)
    print("src:", alignment.hyp_labels)
    print("tgt:", alignment.ref_labels)


def main():
    test_add_bos()
    test_add_eos()
    test_pad()
    test_make_repeat_map()
    test_make_hyp_to_ref_map()
    test_make_repeat()
    test_compute_alignments()


if __name__ == "__main__":
    main()