minor fixes.

This commit is contained in:
Fangjun Kuang 2021-02-28 00:20:02 +08:00
parent efcd8ab92c
commit 0dacee71be
2 changed files with 27 additions and 22 deletions

View File

@ -38,7 +38,7 @@ void PybindFrameExtractionOptions(py::module &m) {
m.def("num_frames", &NumFrames, py::arg("num_samples"), py::arg("opts"),
py::arg("flush") = true);
m.def("get_strided", &GetStrided, py::arg("ave"), py::arg("opts"));
m.def("get_strided", &GetStrided, py::arg("wave"), py::arg("opts"));
}
} // namespace kaldifeat

View File

@ -135,32 +135,37 @@ def test_use_energy_htk_compat_false():
def test_compute_batch():
data1 = read_wave()
data2 = read_wave()
devices = [torch.device('cpu')]
if torch.cuda.is_available():
devices.append(torch.device('cuda', 0))
data = [data1, data2]
fbank_opts = _kaldifeat.FbankOptions()
fbank_opts.frame_opts.dither = 0
fbank = _kaldifeat.Fbank(fbank_opts)
for device in devices:
data1 = read_wave().to(device)
data2 = read_wave().to(device)
def impl(waves: List[torch.Tensor]) -> List[torch.Tensor]:
num_frames = [
_kaldifeat.num_frames(w.numel(), fbank_opts.frame_opts)
for w in waves
]
fbank_opts = _kaldifeat.FbankOptions()
fbank_opts.frame_opts.dither = 0
fbank_opts.set_device(device)
fbank = _kaldifeat.Fbank(fbank_opts)
strided = [
_kaldifeat.get_strided(w, fbank_opts.frame_opts) for w in waves
]
strided = torch.cat(strided, dim=0)
def impl(waves: List[torch.Tensor]) -> List[torch.Tensor]:
num_frames = [
_kaldifeat.num_frames(w.numel(), fbank_opts.frame_opts)
for w in waves
]
features = _kaldifeat.compute(strided, fbank)
feature1 = features[:num_frames[0]]
feature2 = features[num_frames[0]:]
return [feature1, feature2]
strided = [
_kaldifeat.get_strided(w, fbank_opts.frame_opts) for w in waves
]
strided = torch.cat(strided, dim=0)
feature1, feature2 = impl([data1, data2])
assert torch.allclose(feature1, feature2)
features = _kaldifeat.compute(strided, fbank)
feature1 = features[:num_frames[0]]
feature2 = features[num_frames[0]:]
return [feature1, feature2]
feature1, feature2 = impl([data1, data2])
assert torch.allclose(feature1, feature2)
def main():