support torch.device

Fangjun Kuang 2021-02-27 23:40:47 +08:00
parent b2980cdffd
commit f909f839ab
9 changed files with 97 additions and 48 deletions
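Taken together, these changes let the caller choose the device (CPU or CUDA) on which the filter-bank features are computed. Pieced together from the binding and test changes below, usage from Python looks roughly like the following sketch (read_wave is the helper defined in the test file, not part of the library):

    import torch
    import _kaldifeat

    # choose a device; set_device() below accepts torch.device('cpu') or torch.device('cuda', index)
    device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')

    opts = _kaldifeat.FbankOptions()
    opts.frame_opts.dither = 0
    opts.set_device(device)  # new in this commit

    fbank = _kaldifeat.Fbank(opts)
    data = read_wave().to(device)          # the tests move the waveform to the chosen device
    ans = _kaldifeat.compute(data, fbank)  # the tests call ans.cpu() before comparing, so the result stays on that device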


@@ -20,7 +20,7 @@ class OfflineFeatureTpl {
   // using the options class, that we cache at this level.
   OfflineFeatureTpl(const Options &opts)
       : computer_(opts),
-        feature_window_function_(computer_.GetFrameOptions()) {}
+        feature_window_function_(computer_.GetFrameOptions(), opts.device) {}

   /**
      Computes the features for one file (one sequence of features).


@@ -37,7 +37,8 @@ const MelBanks *FbankComputer::GetMelBanks(float vtln_warp) {
   // std::map<float, MelBanks *>::iterator iter = mel_banks_.find(vtln_warp);
   auto iter = mel_banks_.find(vtln_warp);
   if (iter == mel_banks_.end()) {
-    this_mel_banks = new MelBanks(opts_.mel_opts, opts_.frame_opts, vtln_warp);
+    this_mel_banks =
+        new MelBanks(opts_.mel_opts, opts_.frame_opts, vtln_warp, opts_.device);
     mel_banks_[vtln_warp] = this_mel_banks;
   } else {
     this_mel_banks = iter->second;


@@ -12,6 +12,7 @@
 #include "kaldifeat/csrc/feature-common.h"
 #include "kaldifeat/csrc/feature-window.h"
 #include "kaldifeat/csrc/mel-computations.h"
+#include "torch/torch.h"

 namespace kaldifeat {

@@ -37,7 +38,9 @@ struct FbankOptions {
   // analysis, else magnitude.
   bool use_power = true;

-  FbankOptions() { mel_opts.num_bins = 23; }
+  torch::Device device;
+
+  FbankOptions() : device("cpu") { mel_opts.num_bins = 23; }

   std::string ToString() const {
     std::ostringstream os;
@@ -54,6 +57,7 @@ struct FbankOptions {
     os << "htk_compat: " << htk_compat << "\n";
     os << "use_log_fbank: " << use_log_fbank << "\n";
     os << "use_power: " << use_power << "\n";
+    os << "device: " << device << "\n";
     return os.str();
   }
 };
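Because ToString() now reports the device and the Python __str__ binding (further below) simply forwards to ToString(), the selected device is visible when the options object is printed. A small sketch, assuming the set_device binding added later in this commit:

    opts = _kaldifeat.FbankOptions()
    opts.set_device(torch.device('cuda', 0))
    print(opts)  # the dump now includes a line like "device: cuda:0"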


@@ -21,8 +21,8 @@ std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts) {
   return os;
 }

-FeatureWindowFunction::FeatureWindowFunction(
-    const FrameExtractionOptions &opts) {
+FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts,
+                                             torch::Device device) {
   int32_t frame_length = opts.WindowSize();
   KALDIFEAT_ASSERT(frame_length > 0);
@@ -54,6 +54,9 @@ FeatureWindowFunction::FeatureWindowFunction(
   }

   window = window.unsqueeze(0);
+  if (window.device() != device) {
+    window = window.to(device);
+  }
 }

 void FeatureWindowFunction::Apply(torch::Tensor *wave) const {
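The window tensor is still built on the CPU and only moved at the end; the window.device() != device guard (and the identical one for bins_mat_ below) just skips to() when nothing would change. In PyTorch, to() already returns the tensor itself when device and dtype match, so the guard mainly makes the intent explicit. A quick illustration from the Python side:

    t = torch.zeros(3)                     # a CPU tensor
    assert t.to(torch.device('cpu')) is t  # no copy: to() returns the same object when the device already matches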


@@ -79,7 +79,8 @@ std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts);
 class FeatureWindowFunction {
  public:
   FeatureWindowFunction() = default;
-  explicit FeatureWindowFunction(const FrameExtractionOptions &opts);
+  FeatureWindowFunction(const FrameExtractionOptions &opts,
+                        torch::Device device);
   void Apply(torch::Tensor *wave) const;

  private:


@@ -88,7 +88,7 @@ float MelBanks::VtlnWarpMelFreq(
 MelBanks::MelBanks(const MelBanksOptions &opts,
                    const FrameExtractionOptions &frame_opts,
-                   float vtln_warp_factor)
+                   float vtln_warp_factor, torch::Device device)
     : htk_mode_(opts.htk_mode) {
   int32_t num_bins = opts.num_bins;
   if (num_bins < 3) KALDIFEAT_ERR << "Must have at least 3 mel bins";
@@ -182,6 +182,10 @@ MelBanks::MelBanks(const MelBanksOptions &opts,
   if (debug_) KALDIFEAT_LOG << bins_mat_;

   bins_mat_.t_();
+
+  if (bins_mat_.device() != device) {
+    bins_mat_ = bins_mat_.to(device);
+  }
 }

 torch::Tensor MelBanks::Compute(const torch::Tensor &spectrum) const {


@@ -71,7 +71,8 @@ class MelBanks {
                                 float vtln_warp_factor, float mel_freq);

   MelBanks(const MelBanksOptions &opts,
-           const FrameExtractionOptions &frame_opts, float vtln_warp_factor);
+           const FrameExtractionOptions &frame_opts, float vtln_warp_factor,
+           torch::Device device);

   int32_t NumBins() const { return static_cast<int32_t>(bins_mat_.sizes()[0]); }


@@ -19,6 +19,22 @@ void PybindFbankOptions(py::module &m) {
       .def_readwrite("htk_compat", &FbankOptions::htk_compat)
       .def_readwrite("use_log_fbank", &FbankOptions::use_log_fbank)
       .def_readwrite("use_power", &FbankOptions::use_power)
+      .def("set_device",
+           [](FbankOptions *fbank_opts, py::object device) {
+             std::string device_type =
+                 static_cast<py::str>(device.attr("type"));
+             KALDIFEAT_ASSERT(device_type == "cpu" || device_type == "cuda")
+                 << "Unsupported device type: " << device_type;
+
+             auto index_attr = static_cast<py::object>(device.attr("index"));
+             int32_t device_index = 0;
+             if (!index_attr.is_none())
+               device_index = static_cast<py::int_>(index_attr);
+             if (device_type == "cpu")
+               fbank_opts->device = torch::Device("cpu");
+             else
+               fbank_opts->device = torch::Device(torch::kCUDA, device_index);
+           })
       .def("__str__", [](const FbankOptions &self) -> std::string {
         return self.ToString();
       });
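set_device() takes a torch.device object rather than a string: it reads the device's type and index attributes, accepts only 'cpu' and 'cuda' (anything else fails the KALDIFEAT_ASSERT), and treats a missing CUDA index as device 0. A short sketch of what that means on the Python side:

    opts = _kaldifeat.FbankOptions()
    opts.set_device(torch.device('cpu'))      # stored as torch::Device("cpu")
    opts.set_device(torch.device('cuda'))     # index is None, so this selects CUDA device 0
    opts.set_device(torch.device('cuda', 1))  # stored as torch::Device(torch::kCUDA, 1)
    # any other device type is rejected by the assert in the binding above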


@@ -56,61 +56,80 @@ def read_wave() -> torch.Tensor:
 def test_and_benchmark_default_parameters():
-    fbank_opts = _kaldifeat.FbankOptions()
-    fbank_opts.frame_opts.dither = 0
-    fbank = _kaldifeat.Fbank(fbank_opts)
-
-    data = read_wave()
-
-    ans, elapsed_seconds = _kaldifeat._compute_with_elapsed_time(data, fbank)
-    expected = read_ark_txt()
-    assert torch.allclose(ans, expected, rtol=1e-3)
-    print('elapsed seconds:', elapsed_seconds)
+    devices = [torch.device('cpu')]
+    if torch.cuda.is_available():
+        devices.append(torch.device('cuda', 0))
+
+    for device in devices:
+        fbank_opts = _kaldifeat.FbankOptions()
+        fbank_opts.frame_opts.dither = 0
+        fbank_opts.set_device(device)
+        fbank = _kaldifeat.Fbank(fbank_opts)
+
+        data = read_wave().to(device)
+
+        ans, elapsed_seconds = _kaldifeat._compute_with_elapsed_time(
+            data, fbank)
+        expected = read_ark_txt()
+        assert torch.allclose(ans.cpu(), expected, rtol=1e-3)
+        print(f'elapsed seconds {device}:', elapsed_seconds)


 def test_use_energy_htk_compat_true():
-    fbank_opts = _kaldifeat.FbankOptions()
-    fbank_opts.frame_opts.dither = 0
-    fbank_opts.use_energy = True
-    fbank_opts.htk_compat = True
-    fbank = _kaldifeat.Fbank(fbank_opts)
-
-    data = read_wave()
-
-    ans = _kaldifeat.compute(data, fbank)
-    # ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=1 scp:abc.scp ark,t:abc.txt
-    # the first 3 rows are:
-    expected_str = '''
+    devices = [torch.device('cpu')]
+    if torch.cuda.is_available():
+        devices.append(torch.device('cuda', 0))
+
+    for device in devices:
+        fbank_opts = _kaldifeat.FbankOptions()
+        fbank_opts.frame_opts.dither = 0
+        fbank_opts.set_device(device)
+        fbank_opts.use_energy = True
+        fbank_opts.htk_compat = True
+        fbank = _kaldifeat.Fbank(fbank_opts)
+
+        data = read_wave().to(device)
+
+        ans = _kaldifeat.compute(data, fbank)
+
+        # ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=1 scp:abc.scp ark,t:abc.txt
+        # the first 3 rows are:
+        expected_str = '''
 15.576 21.93211 25.55334 24.08283 15.93041 12.47176 10.47909 9.024426 7.899537 6.935482 6.21563 6.035741 6.140291 5.94696 6.146772 6.860236 6.702379 7.087324 6.929666 7.66336 7.935287 8.405977 8.309303 25.38995
 15.5755 21.93212 25.55334 24.08282 15.93044 12.47107 10.47753 9.026523 7.901362 6.939464 6.189109 5.926141 5.678882 5.553694 6.006057 6.066478 6.500169 7.277717 7.248817 7.699819 7.990362 8.033764 8.220113 25.38996
 15.57543 21.93211 25.55334 24.08282 15.93052 12.47129 10.4782 9.028108 7.90429 6.946663 6.310408 5.903729 5.777827 6.027511 6.000434 6.190129 5.968217 6.455313 7.450428 7.993948 8.512851 8.341401 8.14073 25.38995
-    '''
-    expected = parse_str(expected_str)
-    assert torch.allclose(ans[:3, :], expected, rtol=1e-3)
+        '''
+        expected = parse_str(expected_str)
+        assert torch.allclose(ans[:3, :].cpu(), expected, rtol=1e-3)


 def test_use_energy_htk_compat_false():
-    fbank_opts = _kaldifeat.FbankOptions()
-    fbank_opts.frame_opts.dither = 0
-    fbank_opts.use_energy = True
-    fbank_opts.htk_compat = False
-    fbank = _kaldifeat.Fbank(fbank_opts)
-
-    data = read_wave()
-
-    ans = _kaldifeat.compute(data, fbank)
-    # ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=0 scp:abc.scp ark,t:abc.txt
-    # the first 3 rows are:
-    expected_str = '''
+    devices = [torch.device('cpu')]
+    if torch.cuda.is_available():
+        devices.append(torch.device('cuda', 0))
+
+    for device in devices:
+        fbank_opts = _kaldifeat.FbankOptions()
+        fbank_opts.frame_opts.dither = 0
+        fbank_opts.use_energy = True
+        fbank_opts.htk_compat = False
+        fbank_opts.set_device(device)
+        fbank = _kaldifeat.Fbank(fbank_opts)
+
+        data = read_wave().to(device)
+
+        ans = _kaldifeat.compute(data, fbank)
+
+        # ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=0 scp:abc.scp ark,t:abc.txt
+        # the first 3 rows are:
+        expected_str = '''
 25.38995 15.576 21.93211 25.55334 24.08283 15.93041 12.47176 10.47909 9.024426 7.899537 6.935482 6.21563 6.035741 6.140291 5.94696 6.146772 6.860236 6.702379 7.087324 6.929666 7.66336 7.935287 8.405977 8.309303
 25.38996 15.5755 21.93212 25.55334 24.08282 15.93044 12.47107 10.47753 9.026523 7.901362 6.939464 6.189109 5.926141 5.678882 5.553694 6.006057 6.066478 6.500169 7.277717 7.248817 7.699819 7.990362 8.033764 8.220113
 25.38995 15.57543 21.93211 25.55334 24.08282 15.93052 12.47129 10.4782 9.028108 7.90429 6.946663 6.310408 5.903729 5.777827 6.027511 6.000434 6.190129 5.968217 6.455313 7.450428 7.993948 8.512851 8.341401 8.14073
-    '''
-    expected = parse_str(expected_str)
-    assert torch.allclose(ans[:3, :], expected, rtol=1e-3)
+        '''
+        expected = parse_str(expected_str)
+        assert torch.allclose(ans[:3, :].cpu(), expected, rtol=1e-3)


 def main():