mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
134 lines
3.7 KiB
Python
Executable File
134 lines
3.7 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
|
#
|
|
# See ../../LICENSE for clarification regarding multiple authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import k2
|
|
import torch
|
|
|
|
from icefall.lm.rescore import (
|
|
add_bos,
|
|
add_eos,
|
|
make_repeat_map,
|
|
make_hyp_to_ref_map,
|
|
make_repeat,
|
|
compute_alignments,
|
|
)
|
|
|
|
|
|
def test_add_bos():
|
|
bos_id = 100
|
|
ragged = k2.RaggedTensor([[1, 2], [3], [0]])
|
|
bos_ragged = add_bos(ragged, bos_id)
|
|
expected = k2.RaggedTensor([[bos_id, 1, 2], [bos_id, 3], [bos_id, 0]])
|
|
assert str(bos_ragged) == str(expected)
|
|
|
|
|
|
def test_add_eos():
|
|
eos_id = 30
|
|
ragged = k2.RaggedTensor([[1, 2], [3], [], [5, 8, 9]])
|
|
ragged_eos = add_eos(ragged, eos_id)
|
|
expected = k2.RaggedTensor(
|
|
[[1, 2, eos_id], [3, eos_id], [eos_id], [5, 8, 9, eos_id]]
|
|
)
|
|
|
|
|
|
def test_pad():
|
|
bos_id = 10
|
|
eos_id = 100
|
|
ragged = k2.RaggedTensor([[1, 2, 3], [5], [9, 8]])
|
|
bos_ragged = add_bos(ragged, bos_id)
|
|
bos_ragged_eos = add_eos(bos_ragged, eos_id)
|
|
blank_id = -1
|
|
padded = bos_ragged_eos.pad(mode="constant", padding_value=blank_id)
|
|
expected = torch.tensor(
|
|
[
|
|
[bos_id, 1, 2, 3, eos_id],
|
|
[bos_id, 5, eos_id, blank_id, blank_id],
|
|
[bos_id, 9, 8, eos_id, blank_id],
|
|
]
|
|
).to(padded)
|
|
assert torch.all(torch.eq(padded, expected))
|
|
|
|
|
|
def test_make_hyp_to_ref_map():
|
|
a = k2.RaggedTensor([[[1, 2], [], [3]], [[1, 3], [2], [4], [5]]])
|
|
row_splits = a.shape.row_splits(1)
|
|
repeat_map = make_hyp_to_ref_map(row_splits)
|
|
# fmt: off
|
|
expected = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3,
|
|
3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6]).to(repeat_map)
|
|
# fmt: on
|
|
assert torch.all(torch.eq(repeat_map, expected))
|
|
|
|
|
|
def test_make_repeat_map():
|
|
a = k2.RaggedTensor([[[1, 2], [], [3]], [[1, 3], [2], [4], [5]]])
|
|
row_splits = a.shape.row_splits(1)
|
|
repeat_map = make_repeat_map(row_splits)
|
|
# fmt: off
|
|
expected = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2,
|
|
3, 4, 5, 6, 3, 4, 5, 6, 3, 4, 5, 6,
|
|
3, 4, 5, 6]).to(repeat_map)
|
|
# fmt: on
|
|
assert torch.all(torch.eq(repeat_map, expected))
|
|
|
|
|
|
def test_make_repeat():
|
|
# fmt: off
|
|
a = k2.RaggedTensor([
|
|
[[1, 3, 5], [2, 6]],
|
|
[[1, 2, 3, 4], [2], [], [9, 10, 11]],
|
|
])
|
|
b = make_repeat(a)
|
|
expected = k2.RaggedTensor([
|
|
[[1, 3, 5], [2, 6], [1, 3, 5], [2, 6]],
|
|
[[1, 2, 3, 4], [2], [], [9, 10, 11],
|
|
[1, 2, 3, 4], [2], [], [9, 10, 11],
|
|
[1, 2, 3, 4], [2], [], [9, 10, 11],
|
|
[1, 2, 3, 4], [2], [], [9, 10, 11]],
|
|
])
|
|
# fmt: on
|
|
assert str(b) == str(expected)
|
|
|
|
|
|
def test_compute_alignments():
|
|
# fmt: off
|
|
tokens = k2.RaggedTensor([
|
|
# utt 0
|
|
[1, 3], [8], [2],
|
|
# utt 1
|
|
[1, 5], [9],
|
|
])
|
|
shape = k2.RaggedShape('[[x x x] [x x]]')
|
|
# fmt: on
|
|
alignment = compute_alignments(tokens, shape)
|
|
print("maksed_src:", alignment.labels)
|
|
print("src:", alignment.hyp_labels)
|
|
print("tgt:", alignment.ref_labels)
|
|
|
|
|
|
def main():
|
|
test_add_bos()
|
|
test_add_eos()
|
|
test_pad()
|
|
test_make_repeat_map()
|
|
test_make_hyp_to_ref_map()
|
|
test_make_repeat()
|
|
test_compute_alignments()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|