From f909f839abe2d3d2375bcc9f7963f7ac6e7bea40 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 27 Feb 2021 23:40:47 +0800 Subject: [PATCH] support torch.device --- kaldifeat/csrc/feature-common.h | 2 +- kaldifeat/csrc/feature-fbank.cc | 3 +- kaldifeat/csrc/feature-fbank.h | 6 +- kaldifeat/csrc/feature-window.cc | 7 +- kaldifeat/csrc/feature-window.h | 3 +- kaldifeat/csrc/mel-computations.cc | 6 +- kaldifeat/csrc/mel-computations.h | 3 +- kaldifeat/python/csrc/feature-fbank.cc | 16 ++++ kaldifeat/python/tests/test_kaldifeat.py | 99 ++++++++++++++---------- 9 files changed, 97 insertions(+), 48 deletions(-) diff --git a/kaldifeat/csrc/feature-common.h b/kaldifeat/csrc/feature-common.h index ad72d4e..1949eed 100644 --- a/kaldifeat/csrc/feature-common.h +++ b/kaldifeat/csrc/feature-common.h @@ -20,7 +20,7 @@ class OfflineFeatureTpl { // using the options class, that we cache at this level. OfflineFeatureTpl(const Options &opts) : computer_(opts), - feature_window_function_(computer_.GetFrameOptions()) {} + feature_window_function_(computer_.GetFrameOptions(), opts.device) {} /** Computes the features for one file (one sequence of features). diff --git a/kaldifeat/csrc/feature-fbank.cc b/kaldifeat/csrc/feature-fbank.cc index b5ad5a0..f6bb892 100644 --- a/kaldifeat/csrc/feature-fbank.cc +++ b/kaldifeat/csrc/feature-fbank.cc @@ -37,7 +37,8 @@ const MelBanks *FbankComputer::GetMelBanks(float vtln_warp) { // std::map::iterator iter = mel_banks_.find(vtln_warp); auto iter = mel_banks_.find(vtln_warp); if (iter == mel_banks_.end()) { - this_mel_banks = new MelBanks(opts_.mel_opts, opts_.frame_opts, vtln_warp); + this_mel_banks = + new MelBanks(opts_.mel_opts, opts_.frame_opts, vtln_warp, opts_.device); mel_banks_[vtln_warp] = this_mel_banks; } else { this_mel_banks = iter->second; diff --git a/kaldifeat/csrc/feature-fbank.h b/kaldifeat/csrc/feature-fbank.h index d3d686a..01db278 100644 --- a/kaldifeat/csrc/feature-fbank.h +++ b/kaldifeat/csrc/feature-fbank.h @@ -12,6 +12,7 @@ #include "kaldifeat/csrc/feature-common.h" #include "kaldifeat/csrc/feature-window.h" #include "kaldifeat/csrc/mel-computations.h" +#include "torch/torch.h" namespace kaldifeat { @@ -37,7 +38,9 @@ struct FbankOptions { // analysis, else magnitude. bool use_power = true; - FbankOptions() { mel_opts.num_bins = 23; } + torch::Device device; + + FbankOptions() : device("cpu") { mel_opts.num_bins = 23; } std::string ToString() const { std::ostringstream os; @@ -54,6 +57,7 @@ struct FbankOptions { os << "htk_compat: " << htk_compat << "\n"; os << "use_log_fbank: " << use_log_fbank << "\n"; os << "use_power: " << use_power << "\n"; + os << "device: " << device << "\n"; return os.str(); } }; diff --git a/kaldifeat/csrc/feature-window.cc b/kaldifeat/csrc/feature-window.cc index bcb4cd9..a5c10d0 100644 --- a/kaldifeat/csrc/feature-window.cc +++ b/kaldifeat/csrc/feature-window.cc @@ -21,8 +21,8 @@ std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts) { return os; } -FeatureWindowFunction::FeatureWindowFunction( - const FrameExtractionOptions &opts) { +FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts, + torch::Device device) { int32_t frame_length = opts.WindowSize(); KALDIFEAT_ASSERT(frame_length > 0); @@ -54,6 +54,9 @@ FeatureWindowFunction::FeatureWindowFunction( } window = window.unsqueeze(0); + if (window.device() != device) { + window = window.to(device); + } } void FeatureWindowFunction::Apply(torch::Tensor *wave) const { diff --git a/kaldifeat/csrc/feature-window.h b/kaldifeat/csrc/feature-window.h index 09caa5d..990f847 100644 --- a/kaldifeat/csrc/feature-window.h +++ b/kaldifeat/csrc/feature-window.h @@ -79,7 +79,8 @@ std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts); class FeatureWindowFunction { public: FeatureWindowFunction() = default; - explicit FeatureWindowFunction(const FrameExtractionOptions &opts); + FeatureWindowFunction(const FrameExtractionOptions &opts, + torch::Device device); void Apply(torch::Tensor *wave) const; private: diff --git a/kaldifeat/csrc/mel-computations.cc b/kaldifeat/csrc/mel-computations.cc index 22f01eb..acf64c4 100644 --- a/kaldifeat/csrc/mel-computations.cc +++ b/kaldifeat/csrc/mel-computations.cc @@ -88,7 +88,7 @@ float MelBanks::VtlnWarpMelFreq( MelBanks::MelBanks(const MelBanksOptions &opts, const FrameExtractionOptions &frame_opts, - float vtln_warp_factor) + float vtln_warp_factor, torch::Device device) : htk_mode_(opts.htk_mode) { int32_t num_bins = opts.num_bins; if (num_bins < 3) KALDIFEAT_ERR << "Must have at least 3 mel bins"; @@ -182,6 +182,10 @@ MelBanks::MelBanks(const MelBanksOptions &opts, if (debug_) KALDIFEAT_LOG << bins_mat_; bins_mat_.t_(); + + if (bins_mat_.device() != device) { + bins_mat_ = bins_mat_.to(device); + } } torch::Tensor MelBanks::Compute(const torch::Tensor &spectrum) const { diff --git a/kaldifeat/csrc/mel-computations.h b/kaldifeat/csrc/mel-computations.h index 775ba9a..7b3cd5a 100644 --- a/kaldifeat/csrc/mel-computations.h +++ b/kaldifeat/csrc/mel-computations.h @@ -71,7 +71,8 @@ class MelBanks { float vtln_warp_factor, float mel_freq); MelBanks(const MelBanksOptions &opts, - const FrameExtractionOptions &frame_opts, float vtln_warp_factor); + const FrameExtractionOptions &frame_opts, float vtln_warp_factor, + torch::Device device); int32_t NumBins() const { return static_cast(bins_mat_.sizes()[0]); } diff --git a/kaldifeat/python/csrc/feature-fbank.cc b/kaldifeat/python/csrc/feature-fbank.cc index f8c0a78..29b7f78 100644 --- a/kaldifeat/python/csrc/feature-fbank.cc +++ b/kaldifeat/python/csrc/feature-fbank.cc @@ -19,6 +19,22 @@ void PybindFbankOptions(py::module &m) { .def_readwrite("htk_compat", &FbankOptions::htk_compat) .def_readwrite("use_log_fbank", &FbankOptions::use_log_fbank) .def_readwrite("use_power", &FbankOptions::use_power) + .def("set_device", + [](FbankOptions *fbank_opts, py::object device) { + std::string device_type = + static_cast(device.attr("type")); + KALDIFEAT_ASSERT(device_type == "cpu" || device_type == "cuda") + << "Unsupported device type: " << device_type; + + auto index_attr = static_cast(device.attr("index")); + int32_t device_index = 0; + if (!index_attr.is_none()) + device_index = static_cast(index_attr); + if (device_type == "cpu") + fbank_opts->device = torch::Device("cpu"); + else + fbank_opts->device = torch::Device(torch::kCUDA, device_index); + }) .def("__str__", [](const FbankOptions &self) -> std::string { return self.ToString(); }); diff --git a/kaldifeat/python/tests/test_kaldifeat.py b/kaldifeat/python/tests/test_kaldifeat.py index 3e0c24e..0c970cc 100755 --- a/kaldifeat/python/tests/test_kaldifeat.py +++ b/kaldifeat/python/tests/test_kaldifeat.py @@ -56,61 +56,80 @@ def read_wave() -> torch.Tensor: def test_and_benchmark_default_parameters(): - fbank_opts = _kaldifeat.FbankOptions() - fbank_opts.frame_opts.dither = 0 - fbank = _kaldifeat.Fbank(fbank_opts) + devices = [torch.device('cpu')] + if torch.cuda.is_available(): + devices.append(torch.device('cuda', 0)) - data = read_wave() + for device in devices: + fbank_opts = _kaldifeat.FbankOptions() + fbank_opts.frame_opts.dither = 0 + fbank_opts.set_device(device) + fbank = _kaldifeat.Fbank(fbank_opts) - ans, elapsed_seconds = _kaldifeat._compute_with_elapsed_time(data, fbank) + data = read_wave().to(device) - expected = read_ark_txt() - assert torch.allclose(ans, expected, rtol=1e-3) - print('elapsed seconds:', elapsed_seconds) + ans, elapsed_seconds = _kaldifeat._compute_with_elapsed_time( + data, fbank) + + expected = read_ark_txt() + assert torch.allclose(ans.cpu(), expected, rtol=1e-3) + print(f'elapsed seconds {device}:', elapsed_seconds) def test_use_energy_htk_compat_true(): - fbank_opts = _kaldifeat.FbankOptions() - fbank_opts.frame_opts.dither = 0 - fbank_opts.use_energy = True - fbank_opts.htk_compat = True - fbank = _kaldifeat.Fbank(fbank_opts) + devices = [torch.device('cpu')] + if torch.cuda.is_available(): + devices.append(torch.device('cuda', 0)) - data = read_wave() + for device in devices: + fbank_opts = _kaldifeat.FbankOptions() + fbank_opts.frame_opts.dither = 0 + fbank_opts.set_device(device) + fbank_opts.use_energy = True + fbank_opts.htk_compat = True + fbank = _kaldifeat.Fbank(fbank_opts) - ans = _kaldifeat.compute(data, fbank) + data = read_wave().to(device) - # ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=1 scp:abc.scp ark,t:abc.txt - # the first 3 rows are: - expected_str = ''' - 15.576 21.93211 25.55334 24.08283 15.93041 12.47176 10.47909 9.024426 7.899537 6.935482 6.21563 6.035741 6.140291 5.94696 6.146772 6.860236 6.702379 7.087324 6.929666 7.66336 7.935287 8.405977 8.309303 25.38995 - 15.5755 21.93212 25.55334 24.08282 15.93044 12.47107 10.47753 9.026523 7.901362 6.939464 6.189109 5.926141 5.678882 5.553694 6.006057 6.066478 6.500169 7.277717 7.248817 7.699819 7.990362 8.033764 8.220113 25.38996 - 15.57543 21.93211 25.55334 24.08282 15.93052 12.47129 10.4782 9.028108 7.90429 6.946663 6.310408 5.903729 5.777827 6.027511 6.000434 6.190129 5.968217 6.455313 7.450428 7.993948 8.512851 8.341401 8.14073 25.38995 - ''' - expected = parse_str(expected_str) - assert torch.allclose(ans[:3, :], expected, rtol=1e-3) + ans = _kaldifeat.compute(data, fbank) + + # ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=1 scp:abc.scp ark,t:abc.txt + # the first 3 rows are: + expected_str = ''' + 15.576 21.93211 25.55334 24.08283 15.93041 12.47176 10.47909 9.024426 7.899537 6.935482 6.21563 6.035741 6.140291 5.94696 6.146772 6.860236 6.702379 7.087324 6.929666 7.66336 7.935287 8.405977 8.309303 25.38995 + 15.5755 21.93212 25.55334 24.08282 15.93044 12.47107 10.47753 9.026523 7.901362 6.939464 6.189109 5.926141 5.678882 5.553694 6.006057 6.066478 6.500169 7.277717 7.248817 7.699819 7.990362 8.033764 8.220113 25.38996 + 15.57543 21.93211 25.55334 24.08282 15.93052 12.47129 10.4782 9.028108 7.90429 6.946663 6.310408 5.903729 5.777827 6.027511 6.000434 6.190129 5.968217 6.455313 7.450428 7.993948 8.512851 8.341401 8.14073 25.38995 + ''' + expected = parse_str(expected_str) + assert torch.allclose(ans[:3, :].cpu(), expected, rtol=1e-3) def test_use_energy_htk_compat_false(): - fbank_opts = _kaldifeat.FbankOptions() - fbank_opts.frame_opts.dither = 0 - fbank_opts.use_energy = True - fbank_opts.htk_compat = False - fbank = _kaldifeat.Fbank(fbank_opts) + devices = [torch.device('cpu')] + if torch.cuda.is_available(): + devices.append(torch.device('cuda', 0)) - data = read_wave() + for device in devices: + fbank_opts = _kaldifeat.FbankOptions() + fbank_opts.frame_opts.dither = 0 + fbank_opts.use_energy = True + fbank_opts.htk_compat = False + fbank_opts.set_device(device) + fbank = _kaldifeat.Fbank(fbank_opts) - ans = _kaldifeat.compute(data, fbank) + data = read_wave().to(device) - # ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=0 scp:abc.scp ark,t:abc.txt - # the first 3 rows are: - expected_str = ''' - 25.38995 15.576 21.93211 25.55334 24.08283 15.93041 12.47176 10.47909 9.024426 7.899537 6.935482 6.21563 6.035741 6.140291 5.94696 6.146772 6.860236 6.702379 7.087324 6.929666 7.66336 7.935287 8.405977 8.309303 - 25.38996 15.5755 21.93212 25.55334 24.08282 15.93044 12.47107 10.47753 9.026523 7.901362 6.939464 6.189109 5.926141 5.678882 5.553694 6.006057 6.066478 6.500169 7.277717 7.248817 7.699819 7.990362 8.033764 8.220113 - 25.38995 15.57543 21.93211 25.55334 24.08282 15.93052 12.47129 10.4782 9.028108 7.90429 6.946663 6.310408 5.903729 5.777827 6.027511 6.000434 6.190129 5.968217 6.455313 7.450428 7.993948 8.512851 8.341401 8.14073 - ''' - expected = parse_str(expected_str) - assert torch.allclose(ans[:3, :], expected, rtol=1e-3) + ans = _kaldifeat.compute(data, fbank) + + # ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=0 scp:abc.scp ark,t:abc.txt + # the first 3 rows are: + expected_str = ''' + 25.38995 15.576 21.93211 25.55334 24.08283 15.93041 12.47176 10.47909 9.024426 7.899537 6.935482 6.21563 6.035741 6.140291 5.94696 6.146772 6.860236 6.702379 7.087324 6.929666 7.66336 7.935287 8.405977 8.309303 + 25.38996 15.5755 21.93212 25.55334 24.08282 15.93044 12.47107 10.47753 9.026523 7.901362 6.939464 6.189109 5.926141 5.678882 5.553694 6.006057 6.066478 6.500169 7.277717 7.248817 7.699819 7.990362 8.033764 8.220113 + 25.38995 15.57543 21.93211 25.55334 24.08282 15.93052 12.47129 10.4782 9.028108 7.90429 6.946663 6.310408 5.903729 5.777827 6.027511 6.000434 6.190129 5.968217 6.455313 7.450428 7.993948 8.512851 8.341401 8.14073 + ''' + expected = parse_str(expected_str) + assert torch.allclose(ans[:3, :].cpu(), expected, rtol=1e-3) def main():