minor fixes.

This commit is contained in:
Fangjun Kuang 2021-02-28 00:20:02 +08:00
parent efcd8ab92c
commit 0dacee71be
2 changed files with 27 additions and 22 deletions

View File

@ -38,7 +38,7 @@ void PybindFrameExtractionOptions(py::module &m) {
m.def("num_frames", &NumFrames, py::arg("num_samples"), py::arg("opts"), m.def("num_frames", &NumFrames, py::arg("num_samples"), py::arg("opts"),
py::arg("flush") = true); py::arg("flush") = true);
m.def("get_strided", &GetStrided, py::arg("ave"), py::arg("opts")); m.def("get_strided", &GetStrided, py::arg("wave"), py::arg("opts"));
} }
} // namespace kaldifeat } // namespace kaldifeat

View File

@ -135,32 +135,37 @@ def test_use_energy_htk_compat_false():
def test_compute_batch(): def test_compute_batch():
data1 = read_wave() devices = [torch.device('cpu')]
data2 = read_wave() if torch.cuda.is_available():
devices.append(torch.device('cuda', 0))
data = [data1, data2] for device in devices:
fbank_opts = _kaldifeat.FbankOptions() data1 = read_wave().to(device)
fbank_opts.frame_opts.dither = 0 data2 = read_wave().to(device)
fbank = _kaldifeat.Fbank(fbank_opts)
def impl(waves: List[torch.Tensor]) -> List[torch.Tensor]: fbank_opts = _kaldifeat.FbankOptions()
num_frames = [ fbank_opts.frame_opts.dither = 0
_kaldifeat.num_frames(w.numel(), fbank_opts.frame_opts) fbank_opts.set_device(device)
for w in waves fbank = _kaldifeat.Fbank(fbank_opts)
]
strided = [ def impl(waves: List[torch.Tensor]) -> List[torch.Tensor]:
_kaldifeat.get_strided(w, fbank_opts.frame_opts) for w in waves num_frames = [
] _kaldifeat.num_frames(w.numel(), fbank_opts.frame_opts)
strided = torch.cat(strided, dim=0) for w in waves
]
features = _kaldifeat.compute(strided, fbank) strided = [
feature1 = features[:num_frames[0]] _kaldifeat.get_strided(w, fbank_opts.frame_opts) for w in waves
feature2 = features[num_frames[0]:] ]
return [feature1, feature2] strided = torch.cat(strided, dim=0)
feature1, feature2 = impl([data1, data2]) features = _kaldifeat.compute(strided, fbank)
assert torch.allclose(feature1, feature2) feature1 = features[:num_frames[0]]
feature2 = features[num_frames[0]:]
return [feature1, feature2]
feature1, feature2 = impl([data1, data2])
assert torch.allclose(feature1, feature2)
def main(): def main():