mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-10 10:32:16 +00:00
Add kaldifeat.Fbank
This commit is contained in:
parent
be8442fd68
commit
e043afe3d6
4
.flake8
4
.flake8
@ -1,5 +1,9 @@
|
|||||||
[flake8]
|
[flake8]
|
||||||
max-line-length = 80
|
max-line-length = 80
|
||||||
|
|
||||||
|
exclude =
|
||||||
|
.git,
|
||||||
|
kaldifeat/python/kaldifeat/__init__.py
|
||||||
|
|
||||||
ignore =
|
ignore =
|
||||||
E402
|
E402
|
||||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -2,3 +2,4 @@ build/
|
|||||||
build*/
|
build*/
|
||||||
*.egg-info*/
|
*.egg-info*/
|
||||||
dist/
|
dist/
|
||||||
|
__pycache__/
|
||||||
|
@ -9,6 +9,12 @@ set(kaldifeat_srcs
|
|||||||
add_library(kaldifeat_core SHARED ${kaldifeat_srcs})
|
add_library(kaldifeat_core SHARED ${kaldifeat_srcs})
|
||||||
target_link_libraries(kaldifeat_core PUBLIC ${TORCH_LIBRARIES})
|
target_link_libraries(kaldifeat_core PUBLIC ${TORCH_LIBRARIES})
|
||||||
|
|
||||||
|
# PYTHON_INCLUDE_DIRS is set by pybind11
|
||||||
|
target_include_directories(kaldifeat_core PUBLIC ${PYTHON_INCLUDE_DIRS})
|
||||||
|
|
||||||
|
# PYTHON_LIBRARY is set by pybind11
|
||||||
|
target_link_libraries(kaldifeat_core PUBLIC ${PYTHON_LIBRARY})
|
||||||
|
|
||||||
add_executable(test_kaldifeat test_kaldifeat.cc)
|
add_executable(test_kaldifeat test_kaldifeat.cc)
|
||||||
target_link_libraries(test_kaldifeat PRIVATE kaldifeat_core)
|
target_link_libraries(test_kaldifeat PRIVATE kaldifeat_core)
|
||||||
|
|
||||||
|
@ -8,12 +8,16 @@
|
|||||||
#define KALDIFEAT_CSRC_FEATURE_FBANK_H_
|
#define KALDIFEAT_CSRC_FEATURE_FBANK_H_
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
#include "kaldifeat/csrc/feature-common.h"
|
#include "kaldifeat/csrc/feature-common.h"
|
||||||
#include "kaldifeat/csrc/feature-window.h"
|
#include "kaldifeat/csrc/feature-window.h"
|
||||||
#include "kaldifeat/csrc/mel-computations.h"
|
#include "kaldifeat/csrc/mel-computations.h"
|
||||||
|
#include "pybind11/pybind11.h"
|
||||||
#include "torch/torch.h"
|
#include "torch/torch.h"
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
namespace kaldifeat {
|
namespace kaldifeat {
|
||||||
|
|
||||||
struct FbankOptions {
|
struct FbankOptions {
|
||||||
@ -42,6 +46,16 @@ struct FbankOptions {
|
|||||||
|
|
||||||
FbankOptions() : device("cpu") { mel_opts.num_bins = 23; }
|
FbankOptions() : device("cpu") { mel_opts.num_bins = 23; }
|
||||||
|
|
||||||
|
// Get/Set methods are for implementing properties in Python
|
||||||
|
py::object GetDevice() const {
|
||||||
|
py::object ans = py::module_::import("torch").attr("device");
|
||||||
|
return ans(device.str());
|
||||||
|
}
|
||||||
|
void SetDevice(py::object obj) {
|
||||||
|
std::string s = static_cast<py::str>(obj);
|
||||||
|
device = torch::Device(s);
|
||||||
|
}
|
||||||
|
|
||||||
std::string ToString() const {
|
std::string ToString() const {
|
||||||
std::ostringstream os;
|
std::ostringstream os;
|
||||||
os << "frame_opts: \n";
|
os << "frame_opts: \n";
|
||||||
|
@ -19,22 +19,8 @@ void PybindFbankOptions(py::module &m) {
|
|||||||
.def_readwrite("htk_compat", &FbankOptions::htk_compat)
|
.def_readwrite("htk_compat", &FbankOptions::htk_compat)
|
||||||
.def_readwrite("use_log_fbank", &FbankOptions::use_log_fbank)
|
.def_readwrite("use_log_fbank", &FbankOptions::use_log_fbank)
|
||||||
.def_readwrite("use_power", &FbankOptions::use_power)
|
.def_readwrite("use_power", &FbankOptions::use_power)
|
||||||
.def("set_device",
|
.def_property("device", &FbankOptions::GetDevice,
|
||||||
[](FbankOptions *fbank_opts, py::object device) {
|
&FbankOptions::SetDevice)
|
||||||
std::string device_type =
|
|
||||||
static_cast<py::str>(device.attr("type"));
|
|
||||||
KALDIFEAT_ASSERT(device_type == "cpu" || device_type == "cuda")
|
|
||||||
<< "Unsupported device type: " << device_type;
|
|
||||||
|
|
||||||
auto index_attr = static_cast<py::object>(device.attr("index"));
|
|
||||||
int32_t device_index = 0;
|
|
||||||
if (!index_attr.is_none())
|
|
||||||
device_index = static_cast<py::int_>(index_attr);
|
|
||||||
if (device_type == "cpu")
|
|
||||||
fbank_opts->device = torch::Device("cpu");
|
|
||||||
else
|
|
||||||
fbank_opts->device = torch::Device(torch::kCUDA, device_index);
|
|
||||||
})
|
|
||||||
.def("__str__", [](const FbankOptions &self) -> std::string {
|
.def("__str__", [](const FbankOptions &self) -> std::string {
|
||||||
return self.ToString();
|
return self.ToString();
|
||||||
});
|
});
|
||||||
|
@ -27,7 +27,7 @@ PYBIND11_MODULE(_kaldifeat, m) {
|
|||||||
PybindMelBanksOptions(m);
|
PybindMelBanksOptions(m);
|
||||||
PybindFbankOptions(m);
|
PybindFbankOptions(m);
|
||||||
|
|
||||||
m.def("compute", &Compute, py::arg("wave"), py::arg("fbank"));
|
m.def("compute_fbank_feats", &Compute, py::arg("wave"), py::arg("fbank"));
|
||||||
|
|
||||||
// It verifies that the reimplementation produces the same output
|
// It verifies that the reimplementation produces the same output
|
||||||
// as kaldi using default parameters with dither disabled.
|
// as kaldi using default parameters with dither disabled.
|
||||||
|
@ -0,0 +1,3 @@
|
|||||||
|
from _kaldifeat import FbankOptions, FrameExtractionOptions, MelBanksOptions
|
||||||
|
|
||||||
|
from .fbank import Fbank
|
82
kaldifeat/python/kaldifeat/fbank.py
Normal file
82
kaldifeat/python/kaldifeat/fbank.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import _kaldifeat
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class Fbank(nn.Module):
|
||||||
|
def __init__(self, opts: _kaldifeat.FbankOptions):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.opts = opts
|
||||||
|
self.computer = _kaldifeat.Fbank(opts)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, waves: Union[torch.Tensor, List[torch.Tensor]]
|
||||||
|
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||||
|
"""Compute the fbank features of a single waveform or
|
||||||
|
a list of waveforms.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
waves:
|
||||||
|
A single 1-D tensor or a list of 1-D tensors. Each tensor contains
|
||||||
|
audio samples of a soundfile. To get a result compatible with Kaldi,
|
||||||
|
you should scale the samples to [-32768, 32767] before calling this
|
||||||
|
function. Note: You are not required to scale them if you don't care
|
||||||
|
about the compatibility with Kaldi.
|
||||||
|
Returns:
|
||||||
|
Return a list of 2-D tensors containing the fbank features if the
|
||||||
|
input is a list of 1-D tensors. The returned list has as many elements
|
||||||
|
as the input list.
|
||||||
|
Return a single 2-D tensor if the input is a single tensor.
|
||||||
|
"""
|
||||||
|
if isinstance(waves, list):
|
||||||
|
is_list = True
|
||||||
|
else:
|
||||||
|
waves = [waves]
|
||||||
|
is_list = False
|
||||||
|
|
||||||
|
num_frames_per_wave = [
|
||||||
|
_kaldifeat.num_frames(w.numel(), self.opts.frame_opts)
|
||||||
|
for w in waves
|
||||||
|
]
|
||||||
|
|
||||||
|
strided = [self.convert_samples_to_frames(w) for w in waves]
|
||||||
|
strided = torch.cat(strided, dim=0)
|
||||||
|
|
||||||
|
features = self.compute(strided)
|
||||||
|
|
||||||
|
if is_list:
|
||||||
|
return list(features.split(num_frames_per_wave))
|
||||||
|
else:
|
||||||
|
return features
|
||||||
|
|
||||||
|
def compute(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Compute fbank features given a 2-D tensor containing
|
||||||
|
frames data. Each row is a frame of size frame_lens, specified
|
||||||
|
in the fbank options.
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
A 2-D tensor.
|
||||||
|
Returns:
|
||||||
|
Return a 2-D tensor with as many rows as the input tensor. Its
|
||||||
|
number of columns is the number mel bins.
|
||||||
|
"""
|
||||||
|
features = _kaldifeat.compute_fbank_feats(x, self.computer)
|
||||||
|
return features
|
||||||
|
|
||||||
|
def convert_samples_to_frames(self, wave: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Convert a 1-D tensor containing audio samples to a 2-D
|
||||||
|
tensor where each row is a frame of samples of size frame length
|
||||||
|
specified in the fbank options.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
waves:
|
||||||
|
A 1-D tensor.
|
||||||
|
Returns:
|
||||||
|
Return a 2-D tensor.
|
||||||
|
"""
|
||||||
|
return _kaldifeat.get_strided(wave, self.opts.frame_opts)
|
@ -11,6 +11,12 @@ if [ ! -f test.wav ]; then
|
|||||||
sox -n -r 16000 -b 16 test.wav synth 1.2 sine 300-3300
|
sox -n -r 16000 -b 16 test.wav synth 1.2 sine 300-3300
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
if [ ! -f test2.wav ]; then
|
||||||
|
# generate a wav of 0.5 seconds, containing a sine-wave
|
||||||
|
# swept from 300 Hz to 3300 Hz
|
||||||
|
sox -n -r 16000 -b 16 test2.wav synth 0.5 sine 300-3300
|
||||||
|
fi
|
||||||
|
|
||||||
echo "1 test.wav" > test.scp
|
echo "1 test.wav" > test.scp
|
||||||
|
|
||||||
# We disable dither for testing
|
# We disable dither for testing
|
||||||
|
BIN
kaldifeat/python/tests/test_data/test2.wav
Normal file
BIN
kaldifeat/python/tests/test_data/test2.wav
Normal file
Binary file not shown.
72
kaldifeat/python/tests/test_fbank.py
Executable file
72
kaldifeat/python/tests/test_fbank.py
Executable file
@ -0,0 +1,72 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import kaldifeat
|
||||||
|
|
||||||
|
|
||||||
|
def read_wave(filename) -> torch.Tensor:
|
||||||
|
"""Read a wave file and return it as a 1-D tensor.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
You don't need to scale it to [-32768, 32767].
|
||||||
|
We use scaling here to follow the approach in Kaldi.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename:
|
||||||
|
Filename of a sound file.
|
||||||
|
Returns:
|
||||||
|
Return a 1-D tensor containing audio samples.
|
||||||
|
"""
|
||||||
|
with sf.SoundFile(filename) as sf_desc:
|
||||||
|
sampling_rate = sf_desc.samplerate
|
||||||
|
assert sampling_rate == 16000
|
||||||
|
data = sf_desc.read(dtype=np.float32, always_2d=False)
|
||||||
|
data *= 32768
|
||||||
|
return torch.from_numpy(data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fbank():
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
wave0 = read_wave("test_data/test.wav")
|
||||||
|
wave1 = read_wave("test_data/test2.wav")
|
||||||
|
|
||||||
|
wave0 = wave0.to(device)
|
||||||
|
wave1 = wave1.to(device)
|
||||||
|
|
||||||
|
opts = kaldifeat.FbankOptions()
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.device = device
|
||||||
|
|
||||||
|
fbank = kaldifeat.Fbank(opts)
|
||||||
|
|
||||||
|
# We can compute fbank features in batches
|
||||||
|
features = fbank([wave0, wave1])
|
||||||
|
assert isinstance(features, list), f"{type(features)}"
|
||||||
|
assert len(features) == 2
|
||||||
|
|
||||||
|
# We can also compute fbank features for a single wave
|
||||||
|
features0 = fbank(wave0)
|
||||||
|
features1 = fbank(wave1)
|
||||||
|
|
||||||
|
assert torch.allclose(features[0], features0)
|
||||||
|
assert torch.allclose(features[1], features1)
|
||||||
|
|
||||||
|
# To compute fbank features for only a specified frame
|
||||||
|
audio_frames = fbank.convert_samples_to_frames(wave0)
|
||||||
|
feature_frame_1 = fbank.compute(audio_frames[1])
|
||||||
|
feature_frame_10 = fbank.compute(audio_frames[10])
|
||||||
|
|
||||||
|
assert torch.allclose(features0[1], feature_frame_1)
|
||||||
|
assert torch.allclose(features0[10], feature_frame_10)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_fbank()
|
@ -52,7 +52,7 @@ def test_and_benchmark_default_parameters():
|
|||||||
for device in devices:
|
for device in devices:
|
||||||
fbank_opts = _kaldifeat.FbankOptions()
|
fbank_opts = _kaldifeat.FbankOptions()
|
||||||
fbank_opts.frame_opts.dither = 0
|
fbank_opts.frame_opts.dither = 0
|
||||||
fbank_opts.set_device(device)
|
fbank_opts.device = device
|
||||||
fbank = _kaldifeat.Fbank(fbank_opts)
|
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||||
|
|
||||||
data = read_wave().to(device)
|
data = read_wave().to(device)
|
||||||
@ -74,14 +74,14 @@ def test_use_energy_htk_compat_true():
|
|||||||
for device in devices:
|
for device in devices:
|
||||||
fbank_opts = _kaldifeat.FbankOptions()
|
fbank_opts = _kaldifeat.FbankOptions()
|
||||||
fbank_opts.frame_opts.dither = 0
|
fbank_opts.frame_opts.dither = 0
|
||||||
fbank_opts.set_device(device)
|
fbank_opts.device = device
|
||||||
fbank_opts.use_energy = True
|
fbank_opts.use_energy = True
|
||||||
fbank_opts.htk_compat = True
|
fbank_opts.htk_compat = True
|
||||||
fbank = _kaldifeat.Fbank(fbank_opts)
|
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||||
|
|
||||||
data = read_wave().to(device)
|
data = read_wave().to(device)
|
||||||
|
|
||||||
ans = _kaldifeat.compute(data, fbank)
|
ans = _kaldifeat.compute_fbank_feats(data, fbank)
|
||||||
|
|
||||||
expected = read_ark_txt("test-htk.txt")
|
expected = read_ark_txt("test-htk.txt")
|
||||||
assert torch.allclose(ans.cpu(), expected, rtol=1e-2)
|
assert torch.allclose(ans.cpu(), expected, rtol=1e-2)
|
||||||
@ -97,12 +97,12 @@ def test_use_energy_htk_compat_false():
|
|||||||
fbank_opts.frame_opts.dither = 0
|
fbank_opts.frame_opts.dither = 0
|
||||||
fbank_opts.use_energy = True
|
fbank_opts.use_energy = True
|
||||||
fbank_opts.htk_compat = False
|
fbank_opts.htk_compat = False
|
||||||
fbank_opts.set_device(device)
|
fbank_opts.device = device
|
||||||
fbank = _kaldifeat.Fbank(fbank_opts)
|
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||||
|
|
||||||
data = read_wave().to(device)
|
data = read_wave().to(device)
|
||||||
|
|
||||||
ans = _kaldifeat.compute(data, fbank)
|
ans = _kaldifeat.compute_fbank_feats(data, fbank)
|
||||||
|
|
||||||
expected = read_ark_txt("test-with-energy.txt")
|
expected = read_ark_txt("test-with-energy.txt")
|
||||||
assert torch.allclose(ans.cpu(), expected, rtol=1e-2)
|
assert torch.allclose(ans.cpu(), expected, rtol=1e-2)
|
||||||
@ -117,12 +117,12 @@ def test_40_mel():
|
|||||||
fbank_opts = _kaldifeat.FbankOptions()
|
fbank_opts = _kaldifeat.FbankOptions()
|
||||||
fbank_opts.frame_opts.dither = 0
|
fbank_opts.frame_opts.dither = 0
|
||||||
fbank_opts.mel_opts.num_bins = 40
|
fbank_opts.mel_opts.num_bins = 40
|
||||||
fbank_opts.set_device(device)
|
fbank_opts.device = device
|
||||||
fbank = _kaldifeat.Fbank(fbank_opts)
|
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||||
|
|
||||||
data = read_wave().to(device)
|
data = read_wave().to(device)
|
||||||
|
|
||||||
ans = _kaldifeat.compute(data, fbank)
|
ans = _kaldifeat.compute_fbank_feats(data, fbank)
|
||||||
|
|
||||||
expected = read_ark_txt("test-40.txt")
|
expected = read_ark_txt("test-40.txt")
|
||||||
assert torch.allclose(ans.cpu(), expected, rtol=1e-1)
|
assert torch.allclose(ans.cpu(), expected, rtol=1e-1)
|
||||||
@ -138,12 +138,12 @@ def test_40_mel_no_snip_edges():
|
|||||||
fbank_opts.frame_opts.snip_edges = False
|
fbank_opts.frame_opts.snip_edges = False
|
||||||
fbank_opts.frame_opts.dither = 0
|
fbank_opts.frame_opts.dither = 0
|
||||||
fbank_opts.mel_opts.num_bins = 40
|
fbank_opts.mel_opts.num_bins = 40
|
||||||
fbank_opts.set_device(device)
|
fbank_opts.device = device
|
||||||
fbank = _kaldifeat.Fbank(fbank_opts)
|
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||||
|
|
||||||
data = read_wave().to(device)
|
data = read_wave().to(device)
|
||||||
|
|
||||||
ans = _kaldifeat.compute(data, fbank)
|
ans = _kaldifeat.compute_fbank_feats(data, fbank)
|
||||||
|
|
||||||
expected = read_ark_txt("test-40-no-snip-edges.txt")
|
expected = read_ark_txt("test-40-no-snip-edges.txt")
|
||||||
assert torch.allclose(ans.cpu(), expected, rtol=1e-2)
|
assert torch.allclose(ans.cpu(), expected, rtol=1e-2)
|
||||||
@ -161,7 +161,7 @@ def test_compute_batch():
|
|||||||
fbank_opts = _kaldifeat.FbankOptions()
|
fbank_opts = _kaldifeat.FbankOptions()
|
||||||
fbank_opts.frame_opts.dither = 0
|
fbank_opts.frame_opts.dither = 0
|
||||||
fbank_opts.frame_opts.snip_edges = False
|
fbank_opts.frame_opts.snip_edges = False
|
||||||
fbank_opts.set_device(device)
|
fbank_opts.device = device
|
||||||
fbank = _kaldifeat.Fbank(fbank_opts)
|
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||||
|
|
||||||
def impl(waves: List[torch.Tensor]) -> List[torch.Tensor]:
|
def impl(waves: List[torch.Tensor]) -> List[torch.Tensor]:
|
||||||
@ -175,7 +175,9 @@ def test_compute_batch():
|
|||||||
]
|
]
|
||||||
strided = torch.cat(strided, dim=0)
|
strided = torch.cat(strided, dim=0)
|
||||||
|
|
||||||
features = _kaldifeat.compute(strided, fbank).split(num_frames)
|
features = _kaldifeat.compute_fbank_feats(strided, fbank).split(
|
||||||
|
num_frames
|
||||||
|
)
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
@ -52,6 +52,7 @@ def test_fbank_options():
|
|||||||
opts.use_energy = False
|
opts.use_energy = False
|
||||||
opts.use_log_fbank = True
|
opts.use_log_fbank = True
|
||||||
opts.use_power = True
|
opts.use_power = True
|
||||||
|
opts.device = "cuda:0"
|
||||||
|
|
||||||
frame_opts.blackman_coeff = 0.42
|
frame_opts.blackman_coeff = 0.42
|
||||||
frame_opts.dither = 1
|
frame_opts.dither = 1
|
||||||
@ -75,8 +76,8 @@ def test_fbank_options():
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# test_frame_extraction_options()
|
test_frame_extraction_options()
|
||||||
# test_mel_banks_options()
|
test_mel_banks_options()
|
||||||
test_fbank_options()
|
test_fbank_options()
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user