mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-09 18:12:17 +00:00
Add kaldifeat.Fbank
This commit is contained in:
parent
be8442fd68
commit
e043afe3d6
4
.flake8
4
.flake8
@ -1,5 +1,9 @@
|
||||
[flake8]
|
||||
max-line-length = 80
|
||||
|
||||
exclude =
|
||||
.git,
|
||||
kaldifeat/python/kaldifeat/__init__.py
|
||||
|
||||
ignore =
|
||||
E402
|
||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -2,3 +2,4 @@ build/
|
||||
build*/
|
||||
*.egg-info*/
|
||||
dist/
|
||||
__pycache__/
|
||||
|
@ -9,6 +9,12 @@ set(kaldifeat_srcs
|
||||
add_library(kaldifeat_core SHARED ${kaldifeat_srcs})
|
||||
target_link_libraries(kaldifeat_core PUBLIC ${TORCH_LIBRARIES})
|
||||
|
||||
# PYTHON_INCLUDE_DIRS is set by pybind11
|
||||
target_include_directories(kaldifeat_core PUBLIC ${PYTHON_INCLUDE_DIRS})
|
||||
|
||||
# PYTHON_LIBRARY is set by pybind11
|
||||
target_link_libraries(kaldifeat_core PUBLIC ${PYTHON_LIBRARY})
|
||||
|
||||
add_executable(test_kaldifeat test_kaldifeat.cc)
|
||||
target_link_libraries(test_kaldifeat PRIVATE kaldifeat_core)
|
||||
|
||||
|
@ -8,12 +8,16 @@
|
||||
#define KALDIFEAT_CSRC_FEATURE_FBANK_H_
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "kaldifeat/csrc/feature-common.h"
|
||||
#include "kaldifeat/csrc/feature-window.h"
|
||||
#include "kaldifeat/csrc/mel-computations.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "torch/torch.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
struct FbankOptions {
|
||||
@ -42,6 +46,16 @@ struct FbankOptions {
|
||||
|
||||
FbankOptions() : device("cpu") { mel_opts.num_bins = 23; }
|
||||
|
||||
// Get/Set methods are for implementing properties in Python
|
||||
py::object GetDevice() const {
|
||||
py::object ans = py::module_::import("torch").attr("device");
|
||||
return ans(device.str());
|
||||
}
|
||||
void SetDevice(py::object obj) {
|
||||
std::string s = static_cast<py::str>(obj);
|
||||
device = torch::Device(s);
|
||||
}
|
||||
|
||||
std::string ToString() const {
|
||||
std::ostringstream os;
|
||||
os << "frame_opts: \n";
|
||||
|
@ -19,22 +19,8 @@ void PybindFbankOptions(py::module &m) {
|
||||
.def_readwrite("htk_compat", &FbankOptions::htk_compat)
|
||||
.def_readwrite("use_log_fbank", &FbankOptions::use_log_fbank)
|
||||
.def_readwrite("use_power", &FbankOptions::use_power)
|
||||
.def("set_device",
|
||||
[](FbankOptions *fbank_opts, py::object device) {
|
||||
std::string device_type =
|
||||
static_cast<py::str>(device.attr("type"));
|
||||
KALDIFEAT_ASSERT(device_type == "cpu" || device_type == "cuda")
|
||||
<< "Unsupported device type: " << device_type;
|
||||
|
||||
auto index_attr = static_cast<py::object>(device.attr("index"));
|
||||
int32_t device_index = 0;
|
||||
if (!index_attr.is_none())
|
||||
device_index = static_cast<py::int_>(index_attr);
|
||||
if (device_type == "cpu")
|
||||
fbank_opts->device = torch::Device("cpu");
|
||||
else
|
||||
fbank_opts->device = torch::Device(torch::kCUDA, device_index);
|
||||
})
|
||||
.def_property("device", &FbankOptions::GetDevice,
|
||||
&FbankOptions::SetDevice)
|
||||
.def("__str__", [](const FbankOptions &self) -> std::string {
|
||||
return self.ToString();
|
||||
});
|
||||
|
@ -27,7 +27,7 @@ PYBIND11_MODULE(_kaldifeat, m) {
|
||||
PybindMelBanksOptions(m);
|
||||
PybindFbankOptions(m);
|
||||
|
||||
m.def("compute", &Compute, py::arg("wave"), py::arg("fbank"));
|
||||
m.def("compute_fbank_feats", &Compute, py::arg("wave"), py::arg("fbank"));
|
||||
|
||||
// It verifies that the reimplementation produces the same output
|
||||
// as kaldi using default parameters with dither disabled.
|
||||
|
@ -0,0 +1,3 @@
|
||||
from _kaldifeat import FbankOptions, FrameExtractionOptions, MelBanksOptions
|
||||
|
||||
from .fbank import Fbank
|
82
kaldifeat/python/kaldifeat/fbank.py
Normal file
82
kaldifeat/python/kaldifeat/fbank.py
Normal file
@ -0,0 +1,82 @@
|
||||
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
import _kaldifeat
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Fbank(nn.Module):
|
||||
def __init__(self, opts: _kaldifeat.FbankOptions):
|
||||
super().__init__()
|
||||
|
||||
self.opts = opts
|
||||
self.computer = _kaldifeat.Fbank(opts)
|
||||
|
||||
def forward(
|
||||
self, waves: Union[torch.Tensor, List[torch.Tensor]]
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
"""Compute the fbank features of a single waveform or
|
||||
a list of waveforms.
|
||||
|
||||
Args:
|
||||
waves:
|
||||
A single 1-D tensor or a list of 1-D tensors. Each tensor contains
|
||||
audio samples of a soundfile. To get a result compatible with Kaldi,
|
||||
you should scale the samples to [-32768, 32767] before calling this
|
||||
function. Note: You are not required to scale them if you don't care
|
||||
about the compatibility with Kaldi.
|
||||
Returns:
|
||||
Return a list of 2-D tensors containing the fbank features if the
|
||||
input is a list of 1-D tensors. The returned list has as many elements
|
||||
as the input list.
|
||||
Return a single 2-D tensor if the input is a single tensor.
|
||||
"""
|
||||
if isinstance(waves, list):
|
||||
is_list = True
|
||||
else:
|
||||
waves = [waves]
|
||||
is_list = False
|
||||
|
||||
num_frames_per_wave = [
|
||||
_kaldifeat.num_frames(w.numel(), self.opts.frame_opts)
|
||||
for w in waves
|
||||
]
|
||||
|
||||
strided = [self.convert_samples_to_frames(w) for w in waves]
|
||||
strided = torch.cat(strided, dim=0)
|
||||
|
||||
features = self.compute(strided)
|
||||
|
||||
if is_list:
|
||||
return list(features.split(num_frames_per_wave))
|
||||
else:
|
||||
return features
|
||||
|
||||
def compute(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute fbank features given a 2-D tensor containing
|
||||
frames data. Each row is a frame of size frame_lens, specified
|
||||
in the fbank options.
|
||||
Args:
|
||||
x:
|
||||
A 2-D tensor.
|
||||
Returns:
|
||||
Return a 2-D tensor with as many rows as the input tensor. Its
|
||||
number of columns is the number mel bins.
|
||||
"""
|
||||
features = _kaldifeat.compute_fbank_feats(x, self.computer)
|
||||
return features
|
||||
|
||||
def convert_samples_to_frames(self, wave: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert a 1-D tensor containing audio samples to a 2-D
|
||||
tensor where each row is a frame of samples of size frame length
|
||||
specified in the fbank options.
|
||||
|
||||
Args:
|
||||
waves:
|
||||
A 1-D tensor.
|
||||
Returns:
|
||||
Return a 2-D tensor.
|
||||
"""
|
||||
return _kaldifeat.get_strided(wave, self.opts.frame_opts)
|
@ -11,6 +11,12 @@ if [ ! -f test.wav ]; then
|
||||
sox -n -r 16000 -b 16 test.wav synth 1.2 sine 300-3300
|
||||
fi
|
||||
|
||||
if [ ! -f test2.wav ]; then
|
||||
# generate a wav of 0.5 seconds, containing a sine-wave
|
||||
# swept from 300 Hz to 3300 Hz
|
||||
sox -n -r 16000 -b 16 test2.wav synth 0.5 sine 300-3300
|
||||
fi
|
||||
|
||||
echo "1 test.wav" > test.scp
|
||||
|
||||
# We disable dither for testing
|
||||
|
BIN
kaldifeat/python/tests/test_data/test2.wav
Normal file
BIN
kaldifeat/python/tests/test_data/test2.wav
Normal file
Binary file not shown.
72
kaldifeat/python/tests/test_fbank.py
Executable file
72
kaldifeat/python/tests/test_fbank.py
Executable file
@ -0,0 +1,72 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
|
||||
import kaldifeat
|
||||
|
||||
|
||||
def read_wave(filename) -> torch.Tensor:
|
||||
"""Read a wave file and return it as a 1-D tensor.
|
||||
|
||||
Note:
|
||||
You don't need to scale it to [-32768, 32767].
|
||||
We use scaling here to follow the approach in Kaldi.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Filename of a sound file.
|
||||
Returns:
|
||||
Return a 1-D tensor containing audio samples.
|
||||
"""
|
||||
with sf.SoundFile(filename) as sf_desc:
|
||||
sampling_rate = sf_desc.samplerate
|
||||
assert sampling_rate == 16000
|
||||
data = sf_desc.read(dtype=np.float32, always_2d=False)
|
||||
data *= 32768
|
||||
return torch.from_numpy(data)
|
||||
|
||||
|
||||
def test_fbank():
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
wave0 = read_wave("test_data/test.wav")
|
||||
wave1 = read_wave("test_data/test2.wav")
|
||||
|
||||
wave0 = wave0.to(device)
|
||||
wave1 = wave1.to(device)
|
||||
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.frame_opts.dither = 0
|
||||
opts.device = device
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
# We can compute fbank features in batches
|
||||
features = fbank([wave0, wave1])
|
||||
assert isinstance(features, list), f"{type(features)}"
|
||||
assert len(features) == 2
|
||||
|
||||
# We can also compute fbank features for a single wave
|
||||
features0 = fbank(wave0)
|
||||
features1 = fbank(wave1)
|
||||
|
||||
assert torch.allclose(features[0], features0)
|
||||
assert torch.allclose(features[1], features1)
|
||||
|
||||
# To compute fbank features for only a specified frame
|
||||
audio_frames = fbank.convert_samples_to_frames(wave0)
|
||||
feature_frame_1 = fbank.compute(audio_frames[1])
|
||||
feature_frame_10 = fbank.compute(audio_frames[10])
|
||||
|
||||
assert torch.allclose(features0[1], feature_frame_1)
|
||||
assert torch.allclose(features0[10], feature_frame_10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_fbank()
|
@ -52,7 +52,7 @@ def test_and_benchmark_default_parameters():
|
||||
for device in devices:
|
||||
fbank_opts = _kaldifeat.FbankOptions()
|
||||
fbank_opts.frame_opts.dither = 0
|
||||
fbank_opts.set_device(device)
|
||||
fbank_opts.device = device
|
||||
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||
|
||||
data = read_wave().to(device)
|
||||
@ -74,14 +74,14 @@ def test_use_energy_htk_compat_true():
|
||||
for device in devices:
|
||||
fbank_opts = _kaldifeat.FbankOptions()
|
||||
fbank_opts.frame_opts.dither = 0
|
||||
fbank_opts.set_device(device)
|
||||
fbank_opts.device = device
|
||||
fbank_opts.use_energy = True
|
||||
fbank_opts.htk_compat = True
|
||||
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||
|
||||
data = read_wave().to(device)
|
||||
|
||||
ans = _kaldifeat.compute(data, fbank)
|
||||
ans = _kaldifeat.compute_fbank_feats(data, fbank)
|
||||
|
||||
expected = read_ark_txt("test-htk.txt")
|
||||
assert torch.allclose(ans.cpu(), expected, rtol=1e-2)
|
||||
@ -97,12 +97,12 @@ def test_use_energy_htk_compat_false():
|
||||
fbank_opts.frame_opts.dither = 0
|
||||
fbank_opts.use_energy = True
|
||||
fbank_opts.htk_compat = False
|
||||
fbank_opts.set_device(device)
|
||||
fbank_opts.device = device
|
||||
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||
|
||||
data = read_wave().to(device)
|
||||
|
||||
ans = _kaldifeat.compute(data, fbank)
|
||||
ans = _kaldifeat.compute_fbank_feats(data, fbank)
|
||||
|
||||
expected = read_ark_txt("test-with-energy.txt")
|
||||
assert torch.allclose(ans.cpu(), expected, rtol=1e-2)
|
||||
@ -117,12 +117,12 @@ def test_40_mel():
|
||||
fbank_opts = _kaldifeat.FbankOptions()
|
||||
fbank_opts.frame_opts.dither = 0
|
||||
fbank_opts.mel_opts.num_bins = 40
|
||||
fbank_opts.set_device(device)
|
||||
fbank_opts.device = device
|
||||
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||
|
||||
data = read_wave().to(device)
|
||||
|
||||
ans = _kaldifeat.compute(data, fbank)
|
||||
ans = _kaldifeat.compute_fbank_feats(data, fbank)
|
||||
|
||||
expected = read_ark_txt("test-40.txt")
|
||||
assert torch.allclose(ans.cpu(), expected, rtol=1e-1)
|
||||
@ -138,12 +138,12 @@ def test_40_mel_no_snip_edges():
|
||||
fbank_opts.frame_opts.snip_edges = False
|
||||
fbank_opts.frame_opts.dither = 0
|
||||
fbank_opts.mel_opts.num_bins = 40
|
||||
fbank_opts.set_device(device)
|
||||
fbank_opts.device = device
|
||||
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||
|
||||
data = read_wave().to(device)
|
||||
|
||||
ans = _kaldifeat.compute(data, fbank)
|
||||
ans = _kaldifeat.compute_fbank_feats(data, fbank)
|
||||
|
||||
expected = read_ark_txt("test-40-no-snip-edges.txt")
|
||||
assert torch.allclose(ans.cpu(), expected, rtol=1e-2)
|
||||
@ -161,7 +161,7 @@ def test_compute_batch():
|
||||
fbank_opts = _kaldifeat.FbankOptions()
|
||||
fbank_opts.frame_opts.dither = 0
|
||||
fbank_opts.frame_opts.snip_edges = False
|
||||
fbank_opts.set_device(device)
|
||||
fbank_opts.device = device
|
||||
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||
|
||||
def impl(waves: List[torch.Tensor]) -> List[torch.Tensor]:
|
||||
@ -175,7 +175,9 @@ def test_compute_batch():
|
||||
]
|
||||
strided = torch.cat(strided, dim=0)
|
||||
|
||||
features = _kaldifeat.compute(strided, fbank).split(num_frames)
|
||||
features = _kaldifeat.compute_fbank_feats(strided, fbank).split(
|
||||
num_frames
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
|
@ -52,6 +52,7 @@ def test_fbank_options():
|
||||
opts.use_energy = False
|
||||
opts.use_log_fbank = True
|
||||
opts.use_power = True
|
||||
opts.device = "cuda:0"
|
||||
|
||||
frame_opts.blackman_coeff = 0.42
|
||||
frame_opts.dither = 1
|
||||
@ -75,8 +76,8 @@ def test_fbank_options():
|
||||
|
||||
|
||||
def main():
|
||||
# test_frame_extraction_options()
|
||||
# test_mel_banks_options()
|
||||
test_frame_extraction_options()
|
||||
test_mel_banks_options()
|
||||
test_fbank_options()
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user