mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-08 17:42:19 +00:00
minor fixes.
This commit is contained in:
parent
efcd8ab92c
commit
0dacee71be
@ -38,7 +38,7 @@ void PybindFrameExtractionOptions(py::module &m) {
|
||||
m.def("num_frames", &NumFrames, py::arg("num_samples"), py::arg("opts"),
|
||||
py::arg("flush") = true);
|
||||
|
||||
m.def("get_strided", &GetStrided, py::arg("ave"), py::arg("opts"));
|
||||
m.def("get_strided", &GetStrided, py::arg("wave"), py::arg("opts"));
|
||||
}
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
@ -135,32 +135,37 @@ def test_use_energy_htk_compat_false():
|
||||
|
||||
|
||||
def test_compute_batch():
|
||||
data1 = read_wave()
|
||||
data2 = read_wave()
|
||||
devices = [torch.device('cpu')]
|
||||
if torch.cuda.is_available():
|
||||
devices.append(torch.device('cuda', 0))
|
||||
|
||||
data = [data1, data2]
|
||||
fbank_opts = _kaldifeat.FbankOptions()
|
||||
fbank_opts.frame_opts.dither = 0
|
||||
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||
for device in devices:
|
||||
data1 = read_wave().to(device)
|
||||
data2 = read_wave().to(device)
|
||||
|
||||
def impl(waves: List[torch.Tensor]) -> List[torch.Tensor]:
|
||||
num_frames = [
|
||||
_kaldifeat.num_frames(w.numel(), fbank_opts.frame_opts)
|
||||
for w in waves
|
||||
]
|
||||
fbank_opts = _kaldifeat.FbankOptions()
|
||||
fbank_opts.frame_opts.dither = 0
|
||||
fbank_opts.set_device(device)
|
||||
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||
|
||||
strided = [
|
||||
_kaldifeat.get_strided(w, fbank_opts.frame_opts) for w in waves
|
||||
]
|
||||
strided = torch.cat(strided, dim=0)
|
||||
def impl(waves: List[torch.Tensor]) -> List[torch.Tensor]:
|
||||
num_frames = [
|
||||
_kaldifeat.num_frames(w.numel(), fbank_opts.frame_opts)
|
||||
for w in waves
|
||||
]
|
||||
|
||||
features = _kaldifeat.compute(strided, fbank)
|
||||
feature1 = features[:num_frames[0]]
|
||||
feature2 = features[num_frames[0]:]
|
||||
return [feature1, feature2]
|
||||
strided = [
|
||||
_kaldifeat.get_strided(w, fbank_opts.frame_opts) for w in waves
|
||||
]
|
||||
strided = torch.cat(strided, dim=0)
|
||||
|
||||
feature1, feature2 = impl([data1, data2])
|
||||
assert torch.allclose(feature1, feature2)
|
||||
features = _kaldifeat.compute(strided, fbank)
|
||||
feature1 = features[:num_frames[0]]
|
||||
feature2 = features[num_frames[0]:]
|
||||
return [feature1, feature2]
|
||||
|
||||
feature1, feature2 = impl([data1, data2])
|
||||
assert torch.allclose(feature1, feature2)
|
||||
|
||||
|
||||
def main():
|
||||
|
Loading…
x
Reference in New Issue
Block a user