mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
166 lines
3.9 KiB
Python
166 lines
3.9 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
|
#
|
|
# See ../../LICENSE for clarification regarding multiple authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
import k2
|
|
import pytest
|
|
import torch
|
|
|
|
from icefall.env import get_env_info
|
|
from icefall.utils import (
|
|
AttributeDict,
|
|
add_eos,
|
|
add_sos,
|
|
encode_supervisions,
|
|
get_texts,
|
|
make_pad_mask,
|
|
)
|
|
|
|
|
|
@pytest.fixture
def sup():
    """Supervision fixture with three segments, consumed by test_encode_supervisions.

    Each field is a parallel collection indexed by segment: segment ``i`` has
    ``sequence_idx[i]``, ``start_frame[i]``, ``num_frames[i]`` and ``text[i]``.
    """
    return dict(
        sequence_idx=torch.tensor([0, 1, 2]),
        start_frame=torch.tensor([1, 3, 9]),
        num_frames=torch.tensor([20, 30, 10]),
        text=["one", "two", "three"],
    )
|
|
|
|
|
|
def test_encode_supervisions(sup):
    """Check encode_supervisions ordering and subsampling on the ``sup`` fixture.

    The expected rows are ``[sequence_idx, start_frame // 4, num_frames // 4]``
    ordered by ``num_frames`` from longest (30) to shortest (10), and the
    returned texts are reordered the same way.
    """
    supervision_segments, texts = encode_supervisions(sup, subsampling_factor=4)
    expected = torch.tensor(
        [
            [1, 0, 30 // 4],
            [0, 0, 20 // 4],
            [2, 9 // 4, 10 // 4],
        ]
    )
    # torch.eq broadcasts its operands, so a result with the wrong shape could
    # still compare "all equal"; pin the shape explicitly first.
    assert supervision_segments.shape == expected.shape
    assert torch.all(torch.eq(supervision_segments, expected))
    assert texts == ["two", "one", "three"]
|
|
|
|
|
|
def test_get_texts_ragged():
    """get_texts() on a batch of FSAs whose aux_labels are k2.RaggedTensor.

    Each arc carries a (possibly empty) list of aux labels. The expected text
    for each FSA is its aux labels concatenated in arc order with the 0 and -1
    entries removed.
    """
    first = k2.Fsa.from_str(
        """
        0 1 1 10
        1 2 2 20
        2 3 3 30
        3 4 -1 0
        4
        """
    )
    # One sublist per arc (4 arcs above).
    first.aux_labels = k2.RaggedTensor("[ [1 3 0 2] [] [4 0 1] [-1]]")

    second = k2.Fsa.from_str(
        """
        0 1 1 1
        1 2 2 2
        2 3 -1 0
        3
        """
    )
    # One sublist per arc (3 arcs above).
    second.aux_labels = k2.RaggedTensor("[[3 0 5 0 8] [0 9 7 0] [-1]]")

    batch = k2.Fsa.from_fsas([first, second])
    expected = [[1, 3, 2, 4, 1], [3, 5, 8, 9, 7]]
    assert get_texts(batch) == expected
|
|
|
|
|
|
def test_get_texts_regular():
    """get_texts() on a batch of FSAs with one regular aux label per arc.

    Arc lines are ``src dst label aux_label score``. The expected text for each
    FSA is its aux labels in arc order with the 0 and -1 entries removed.
    """
    first = k2.Fsa.from_str(
        """
        0 1 1 3 10
        1 2 2 0 20
        2 3 3 2 30
        3 4 -1 -1 0
        4
        """,
        num_aux_labels=1,
    )

    second = k2.Fsa.from_str(
        """
        0 1 1 10 1
        1 2 2 5 2
        2 3 -1 -1 0
        3
        """,
        num_aux_labels=1,
    )

    batch = k2.Fsa.from_fsas([first, second])
    expected = [[3, 2], [10, 5]]
    assert get_texts(batch) == expected
|
|
|
|
|
|
def test_attribute_dict():
    """AttributeDict keeps attribute access and item access in sync.

    Covers construction from a dict, attribute/item reads, attribute writes
    (both ``obj.x = ...`` and ``setattr``), attribute deletion, and the
    AttributeError raised when deleting an attribute that no longer exists.
    """
    s = AttributeDict({"a": 10, "b": 20})
    assert s.a == 10
    assert s["b"] == 20

    # Attribute assignment is visible through item access.
    s.c = 100
    assert s["c"] == 100

    assert hasattr(s, "a")
    assert hasattr(s, "b")
    assert getattr(s, "a") == 10

    del s.a
    assert not hasattr(s, "a")

    # setattr() must behave like plain attribute assignment; use a new value
    # so the assertions actually observe the write.
    setattr(s, "c", 200)
    assert s.c == 200
    assert s["c"] == 200

    # pytest.raises fails the test if nothing is raised, unlike a bare
    # try/except, which would pass silently.
    with pytest.raises(AttributeError):
        del s.a
|
|
|
|
|
|
def test_get_env_info():
    """Smoke test: get_env_info() runs without raising and its result prints."""
    print(get_env_info())
|
|
|
|
|
|
def test_makd_pad_mask():
    """make_pad_mask marks positions at or beyond each length as padding.

    For lengths ``[1, 3, 2]`` the expected mask has ``False`` for the first
    ``lengths[i]`` positions of row ``i`` and ``True`` elsewhere.
    """
    lengths = torch.tensor([1, 3, 2])
    mask = make_pad_mask(lengths)
    expected = torch.tensor(
        [
            [False, True, True],
            [False, False, False],
            [False, False, True],
        ]
    )
    # torch.eq broadcasts its operands, so pin the shape explicitly first.
    assert mask.shape == expected.shape
    assert torch.all(torch.eq(mask, expected))
    # The number of non-padded positions in the *returned* mask must equal
    # the total of all lengths.
    assert (~mask).sum().item() == lengths.sum().item()
|
|
|
|
|
|
def test_add_sos():
    """add_sos() prepends sos_id to every sublist of a ragged tensor."""
    sos_id = 100
    rows = [[1, 2], [3], [0]]
    result = add_sos(k2.RaggedTensor(rows), sos_id)
    expected = k2.RaggedTensor([[sos_id, *row] for row in rows])
    assert str(result) == str(expected)
|
|
|
|
|
|
def test_add_eos():
    """add_eos() appends eos_id to every sublist, including an empty one."""
    eos_id = 30
    rows = [[1, 2], [3], [], [5, 8, 9]]
    result = add_eos(k2.RaggedTensor(rows), eos_id)
    expected = k2.RaggedTensor([row + [eos_id] for row in rows])
    assert str(result) == str(expected)
|