diff --git a/.flake8 b/.flake8 index 0e88669..3551e08 100644 --- a/.flake8 +++ b/.flake8 @@ -1,5 +1,9 @@ [flake8] max-line-length = 80 +exclude = + .git, + kaldifeat/python/kaldifeat/__init__.py + ignore = E402 diff --git a/.gitignore b/.gitignore index f9f78d8..c697d52 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ build/ build*/ *.egg-info*/ dist/ +__pycache__/ diff --git a/kaldifeat/csrc/CMakeLists.txt b/kaldifeat/csrc/CMakeLists.txt index eddab10..d94d43a 100644 --- a/kaldifeat/csrc/CMakeLists.txt +++ b/kaldifeat/csrc/CMakeLists.txt @@ -9,6 +9,12 @@ set(kaldifeat_srcs add_library(kaldifeat_core SHARED ${kaldifeat_srcs}) target_link_libraries(kaldifeat_core PUBLIC ${TORCH_LIBRARIES}) +# PYTHON_INCLUDE_DIRS is set by pybind11 +target_include_directories(kaldifeat_core PUBLIC ${PYTHON_INCLUDE_DIRS}) + +# PYTHON_LIBRARY is set by pybind11 +target_link_libraries(kaldifeat_core PUBLIC ${PYTHON_LIBRARY}) + add_executable(test_kaldifeat test_kaldifeat.cc) target_link_libraries(test_kaldifeat PRIVATE kaldifeat_core) diff --git a/kaldifeat/csrc/feature-fbank.h b/kaldifeat/csrc/feature-fbank.h index fc40359..80a3ba9 100644 --- a/kaldifeat/csrc/feature-fbank.h +++ b/kaldifeat/csrc/feature-fbank.h @@ -8,12 +8,16 @@ #define KALDIFEAT_CSRC_FEATURE_FBANK_H_ #include +#include #include "kaldifeat/csrc/feature-common.h" #include "kaldifeat/csrc/feature-window.h" #include "kaldifeat/csrc/mel-computations.h" +#include "pybind11/pybind11.h" #include "torch/torch.h" +namespace py = pybind11; + namespace kaldifeat { struct FbankOptions { @@ -42,6 +46,16 @@ struct FbankOptions { FbankOptions() : device("cpu") { mel_opts.num_bins = 23; } + // Get/Set methods are for implementing properties in Python + py::object GetDevice() const { + py::object ans = py::module_::import("torch").attr("device"); + return ans(device.str()); + } + void SetDevice(py::object obj) { + std::string s = static_cast(obj); + device = torch::Device(s); + } + std::string ToString() const { std::ostringstream os; os << "frame_opts: \n"; diff --git a/kaldifeat/python/csrc/feature-fbank.cc b/kaldifeat/python/csrc/feature-fbank.cc index 5122605..5f26a6b 100644 --- a/kaldifeat/python/csrc/feature-fbank.cc +++ b/kaldifeat/python/csrc/feature-fbank.cc @@ -19,22 +19,8 @@ void PybindFbankOptions(py::module &m) { .def_readwrite("htk_compat", &FbankOptions::htk_compat) .def_readwrite("use_log_fbank", &FbankOptions::use_log_fbank) .def_readwrite("use_power", &FbankOptions::use_power) - .def("set_device", - [](FbankOptions *fbank_opts, py::object device) { - std::string device_type = - static_cast(device.attr("type")); - KALDIFEAT_ASSERT(device_type == "cpu" || device_type == "cuda") - << "Unsupported device type: " << device_type; - - auto index_attr = static_cast(device.attr("index")); - int32_t device_index = 0; - if (!index_attr.is_none()) - device_index = static_cast(index_attr); - if (device_type == "cpu") - fbank_opts->device = torch::Device("cpu"); - else - fbank_opts->device = torch::Device(torch::kCUDA, device_index); - }) + .def_property("device", &FbankOptions::GetDevice, + &FbankOptions::SetDevice) .def("__str__", [](const FbankOptions &self) -> std::string { return self.ToString(); }); diff --git a/kaldifeat/python/csrc/kaldifeat.cc b/kaldifeat/python/csrc/kaldifeat.cc index 3398f8d..ca4bd79 100644 --- a/kaldifeat/python/csrc/kaldifeat.cc +++ b/kaldifeat/python/csrc/kaldifeat.cc @@ -27,7 +27,7 @@ PYBIND11_MODULE(_kaldifeat, m) { PybindMelBanksOptions(m); PybindFbankOptions(m); - m.def("compute", &Compute, py::arg("wave"), py::arg("fbank")); + m.def("compute_fbank_feats", &Compute, py::arg("wave"), py::arg("fbank")); // It verifies that the reimplementation produces the same output // as kaldi using default parameters with dither disabled. diff --git a/kaldifeat/python/kaldifeat/__init__.py b/kaldifeat/python/kaldifeat/__init__.py index e69de29..e177288 100644 --- a/kaldifeat/python/kaldifeat/__init__.py +++ b/kaldifeat/python/kaldifeat/__init__.py @@ -0,0 +1,3 @@ +from _kaldifeat import FbankOptions, FrameExtractionOptions, MelBanksOptions + +from .fbank import Fbank diff --git a/kaldifeat/python/kaldifeat/fbank.py b/kaldifeat/python/kaldifeat/fbank.py new file mode 100644 index 0000000..5196956 --- /dev/null +++ b/kaldifeat/python/kaldifeat/fbank.py @@ -0,0 +1,82 @@ +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +from typing import List, Union + +import _kaldifeat +import torch +import torch.nn as nn + + +class Fbank(nn.Module): + def __init__(self, opts: _kaldifeat.FbankOptions): + super().__init__() + + self.opts = opts + self.computer = _kaldifeat.Fbank(opts) + + def forward( + self, waves: Union[torch.Tensor, List[torch.Tensor]] + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Compute the fbank features of a single waveform or + a list of waveforms. + + Args: + waves: + A single 1-D tensor or a list of 1-D tensors. Each tensor contains + audio samples of a soundfile. To get a result compatible with Kaldi, + you should scale the samples to [-32768, 32767] before calling this + function. Note: You are not required to scale them if you don't care + about the compatibility with Kaldi. + Returns: + Return a list of 2-D tensors containing the fbank features if the + input is a list of 1-D tensors. The returned list has as many elements + as the input list. + Return a single 2-D tensor if the input is a single tensor. + """ + if isinstance(waves, list): + is_list = True + else: + waves = [waves] + is_list = False + + num_frames_per_wave = [ + _kaldifeat.num_frames(w.numel(), self.opts.frame_opts) + for w in waves + ] + + strided = [self.convert_samples_to_frames(w) for w in waves] + strided = torch.cat(strided, dim=0) + + features = self.compute(strided) + + if is_list: + return list(features.split(num_frames_per_wave)) + else: + return features + + def compute(self, x: torch.Tensor) -> torch.Tensor: + """Compute fbank features given a 2-D tensor containing + frames data. Each row is a frame of size frame_lens, specified + in the fbank options. + Args: + x: + A 2-D tensor. + Returns: + Return a 2-D tensor with as many rows as the input tensor. Its + number of columns is the number mel bins. + """ + features = _kaldifeat.compute_fbank_feats(x, self.computer) + return features + + def convert_samples_to_frames(self, wave: torch.Tensor) -> torch.Tensor: + """Convert a 1-D tensor containing audio samples to a 2-D + tensor where each row is a frame of samples of size frame length + specified in the fbank options. + + Args: + waves: + A 1-D tensor. + Returns: + Return a 2-D tensor. + """ + return _kaldifeat.get_strided(wave, self.opts.frame_opts) diff --git a/kaldifeat/python/tests/test_data/run.sh b/kaldifeat/python/tests/test_data/run.sh index 53b3f39..361b657 100755 --- a/kaldifeat/python/tests/test_data/run.sh +++ b/kaldifeat/python/tests/test_data/run.sh @@ -11,6 +11,12 @@ if [ ! -f test.wav ]; then sox -n -r 16000 -b 16 test.wav synth 1.2 sine 300-3300 fi +if [ ! -f test2.wav ]; then + # generate a wav of 0.5 seconds, containing a sine-wave + # swept from 300 Hz to 3300 Hz + sox -n -r 16000 -b 16 test2.wav synth 0.5 sine 300-3300 +fi + echo "1 test.wav" > test.scp # We disable dither for testing diff --git a/kaldifeat/python/tests/test_data/test2.wav b/kaldifeat/python/tests/test_data/test2.wav new file mode 100644 index 0000000..0016f50 Binary files /dev/null and b/kaldifeat/python/tests/test_data/test2.wav differ diff --git a/kaldifeat/python/tests/test_fbank.py b/kaldifeat/python/tests/test_fbank.py new file mode 100755 index 0000000..0f39a1c --- /dev/null +++ b/kaldifeat/python/tests/test_fbank.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 + +# Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +import numpy as np +import soundfile as sf +import torch + +import kaldifeat + + +def read_wave(filename) -> torch.Tensor: + """Read a wave file and return it as a 1-D tensor. + + Note: + You don't need to scale it to [-32768, 32767]. + We use scaling here to follow the approach in Kaldi. + + Args: + filename: + Filename of a sound file. + Returns: + Return a 1-D tensor containing audio samples. + """ + with sf.SoundFile(filename) as sf_desc: + sampling_rate = sf_desc.samplerate + assert sampling_rate == 16000 + data = sf_desc.read(dtype=np.float32, always_2d=False) + data *= 32768 + return torch.from_numpy(data) + + +def test_fbank(): + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + wave0 = read_wave("test_data/test.wav") + wave1 = read_wave("test_data/test2.wav") + + wave0 = wave0.to(device) + wave1 = wave1.to(device) + + opts = kaldifeat.FbankOptions() + opts.frame_opts.dither = 0 + opts.device = device + + fbank = kaldifeat.Fbank(opts) + + # We can compute fbank features in batches + features = fbank([wave0, wave1]) + assert isinstance(features, list), f"{type(features)}" + assert len(features) == 2 + + # We can also compute fbank features for a single wave + features0 = fbank(wave0) + features1 = fbank(wave1) + + assert torch.allclose(features[0], features0) + assert torch.allclose(features[1], features1) + + # To compute fbank features for only a specified frame + audio_frames = fbank.convert_samples_to_frames(wave0) + feature_frame_1 = fbank.compute(audio_frames[1]) + feature_frame_10 = fbank.compute(audio_frames[10]) + + assert torch.allclose(features0[1], feature_frame_1) + assert torch.allclose(features0[10], feature_frame_10) + + +if __name__ == "__main__": + test_fbank() diff --git a/kaldifeat/python/tests/test_kaldifeat.py b/kaldifeat/python/tests/test_kaldifeat.py index c41712c..bcf14c8 100755 --- a/kaldifeat/python/tests/test_kaldifeat.py +++ b/kaldifeat/python/tests/test_kaldifeat.py @@ -52,7 +52,7 @@ def test_and_benchmark_default_parameters(): for device in devices: fbank_opts = _kaldifeat.FbankOptions() fbank_opts.frame_opts.dither = 0 - fbank_opts.set_device(device) + fbank_opts.device = device fbank = _kaldifeat.Fbank(fbank_opts) data = read_wave().to(device) @@ -74,14 +74,14 @@ def test_use_energy_htk_compat_true(): for device in devices: fbank_opts = _kaldifeat.FbankOptions() fbank_opts.frame_opts.dither = 0 - fbank_opts.set_device(device) + fbank_opts.device = device fbank_opts.use_energy = True fbank_opts.htk_compat = True fbank = _kaldifeat.Fbank(fbank_opts) data = read_wave().to(device) - ans = _kaldifeat.compute(data, fbank) + ans = _kaldifeat.compute_fbank_feats(data, fbank) expected = read_ark_txt("test-htk.txt") assert torch.allclose(ans.cpu(), expected, rtol=1e-2) @@ -97,12 +97,12 @@ def test_use_energy_htk_compat_false(): fbank_opts.frame_opts.dither = 0 fbank_opts.use_energy = True fbank_opts.htk_compat = False - fbank_opts.set_device(device) + fbank_opts.device = device fbank = _kaldifeat.Fbank(fbank_opts) data = read_wave().to(device) - ans = _kaldifeat.compute(data, fbank) + ans = _kaldifeat.compute_fbank_feats(data, fbank) expected = read_ark_txt("test-with-energy.txt") assert torch.allclose(ans.cpu(), expected, rtol=1e-2) @@ -117,12 +117,12 @@ def test_40_mel(): fbank_opts = _kaldifeat.FbankOptions() fbank_opts.frame_opts.dither = 0 fbank_opts.mel_opts.num_bins = 40 - fbank_opts.set_device(device) + fbank_opts.device = device fbank = _kaldifeat.Fbank(fbank_opts) data = read_wave().to(device) - ans = _kaldifeat.compute(data, fbank) + ans = _kaldifeat.compute_fbank_feats(data, fbank) expected = read_ark_txt("test-40.txt") assert torch.allclose(ans.cpu(), expected, rtol=1e-1) @@ -138,12 +138,12 @@ def test_40_mel_no_snip_edges(): fbank_opts.frame_opts.snip_edges = False fbank_opts.frame_opts.dither = 0 fbank_opts.mel_opts.num_bins = 40 - fbank_opts.set_device(device) + fbank_opts.device = device fbank = _kaldifeat.Fbank(fbank_opts) data = read_wave().to(device) - ans = _kaldifeat.compute(data, fbank) + ans = _kaldifeat.compute_fbank_feats(data, fbank) expected = read_ark_txt("test-40-no-snip-edges.txt") assert torch.allclose(ans.cpu(), expected, rtol=1e-2) @@ -161,7 +161,7 @@ def test_compute_batch(): fbank_opts = _kaldifeat.FbankOptions() fbank_opts.frame_opts.dither = 0 fbank_opts.frame_opts.snip_edges = False - fbank_opts.set_device(device) + fbank_opts.device = device fbank = _kaldifeat.Fbank(fbank_opts) def impl(waves: List[torch.Tensor]) -> List[torch.Tensor]: @@ -175,7 +175,9 @@ def test_compute_batch(): ] strided = torch.cat(strided, dim=0) - features = _kaldifeat.compute(strided, fbank).split(num_frames) + features = _kaldifeat.compute_fbank_feats(strided, fbank).split( + num_frames + ) return features diff --git a/kaldifeat/python/tests/test_options.py b/kaldifeat/python/tests/test_options.py index 41e660b..e4a01a7 100755 --- a/kaldifeat/python/tests/test_options.py +++ b/kaldifeat/python/tests/test_options.py @@ -52,6 +52,7 @@ def test_fbank_options(): opts.use_energy = False opts.use_log_fbank = True opts.use_power = True + opts.device = "cuda:0" frame_opts.blackman_coeff = 0.42 frame_opts.dither = 1 @@ -75,8 +76,8 @@ def test_fbank_options(): def main(): - # test_frame_extraction_options() - # test_mel_banks_options() + test_frame_extraction_options() + test_mel_banks_options() test_fbank_options()