mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-09 01:52:39 +00:00
minor fixes.
This commit is contained in:
parent
efcd8ab92c
commit
0dacee71be
@ -38,7 +38,7 @@ void PybindFrameExtractionOptions(py::module &m) {
|
|||||||
m.def("num_frames", &NumFrames, py::arg("num_samples"), py::arg("opts"),
|
m.def("num_frames", &NumFrames, py::arg("num_samples"), py::arg("opts"),
|
||||||
py::arg("flush") = true);
|
py::arg("flush") = true);
|
||||||
|
|
||||||
m.def("get_strided", &GetStrided, py::arg("ave"), py::arg("opts"));
|
m.def("get_strided", &GetStrided, py::arg("wave"), py::arg("opts"));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace kaldifeat
|
} // namespace kaldifeat
|
||||||
|
@ -135,32 +135,37 @@ def test_use_energy_htk_compat_false():
|
|||||||
|
|
||||||
|
|
||||||
def test_compute_batch():
|
def test_compute_batch():
|
||||||
data1 = read_wave()
|
devices = [torch.device('cpu')]
|
||||||
data2 = read_wave()
|
if torch.cuda.is_available():
|
||||||
|
devices.append(torch.device('cuda', 0))
|
||||||
|
|
||||||
data = [data1, data2]
|
for device in devices:
|
||||||
fbank_opts = _kaldifeat.FbankOptions()
|
data1 = read_wave().to(device)
|
||||||
fbank_opts.frame_opts.dither = 0
|
data2 = read_wave().to(device)
|
||||||
fbank = _kaldifeat.Fbank(fbank_opts)
|
|
||||||
|
|
||||||
def impl(waves: List[torch.Tensor]) -> List[torch.Tensor]:
|
fbank_opts = _kaldifeat.FbankOptions()
|
||||||
num_frames = [
|
fbank_opts.frame_opts.dither = 0
|
||||||
_kaldifeat.num_frames(w.numel(), fbank_opts.frame_opts)
|
fbank_opts.set_device(device)
|
||||||
for w in waves
|
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||||
]
|
|
||||||
|
|
||||||
strided = [
|
def impl(waves: List[torch.Tensor]) -> List[torch.Tensor]:
|
||||||
_kaldifeat.get_strided(w, fbank_opts.frame_opts) for w in waves
|
num_frames = [
|
||||||
]
|
_kaldifeat.num_frames(w.numel(), fbank_opts.frame_opts)
|
||||||
strided = torch.cat(strided, dim=0)
|
for w in waves
|
||||||
|
]
|
||||||
|
|
||||||
features = _kaldifeat.compute(strided, fbank)
|
strided = [
|
||||||
feature1 = features[:num_frames[0]]
|
_kaldifeat.get_strided(w, fbank_opts.frame_opts) for w in waves
|
||||||
feature2 = features[num_frames[0]:]
|
]
|
||||||
return [feature1, feature2]
|
strided = torch.cat(strided, dim=0)
|
||||||
|
|
||||||
feature1, feature2 = impl([data1, data2])
|
features = _kaldifeat.compute(strided, fbank)
|
||||||
assert torch.allclose(feature1, feature2)
|
feature1 = features[:num_frames[0]]
|
||||||
|
feature2 = features[num_frames[0]:]
|
||||||
|
return [feature1, feature2]
|
||||||
|
|
||||||
|
feature1, feature2 = impl([data1, data2])
|
||||||
|
assert torch.allclose(feature1, feature2)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user