support torch.device

This commit is contained in:
Fangjun Kuang 2021-02-27 23:40:47 +08:00
parent b2980cdffd
commit f909f839ab
9 changed files with 97 additions and 48 deletions

View File

@ -20,7 +20,7 @@ class OfflineFeatureTpl {
// using the options class, that we cache at this level.
OfflineFeatureTpl(const Options &opts)
: computer_(opts),
feature_window_function_(computer_.GetFrameOptions()) {}
feature_window_function_(computer_.GetFrameOptions(), opts.device) {}
/**
Computes the features for one file (one sequence of features).

View File

@ -37,7 +37,8 @@ const MelBanks *FbankComputer::GetMelBanks(float vtln_warp) {
// std::map<float, MelBanks *>::iterator iter = mel_banks_.find(vtln_warp);
auto iter = mel_banks_.find(vtln_warp);
if (iter == mel_banks_.end()) {
this_mel_banks = new MelBanks(opts_.mel_opts, opts_.frame_opts, vtln_warp);
this_mel_banks =
new MelBanks(opts_.mel_opts, opts_.frame_opts, vtln_warp, opts_.device);
mel_banks_[vtln_warp] = this_mel_banks;
} else {
this_mel_banks = iter->second;

View File

@ -12,6 +12,7 @@
#include "kaldifeat/csrc/feature-common.h"
#include "kaldifeat/csrc/feature-window.h"
#include "kaldifeat/csrc/mel-computations.h"
#include "torch/torch.h"
namespace kaldifeat {
@ -37,7 +38,9 @@ struct FbankOptions {
// analysis, else magnitude.
bool use_power = true;
FbankOptions() { mel_opts.num_bins = 23; }
torch::Device device;
FbankOptions() : device("cpu") { mel_opts.num_bins = 23; }
std::string ToString() const {
std::ostringstream os;
@ -54,6 +57,7 @@ struct FbankOptions {
os << "htk_compat: " << htk_compat << "\n";
os << "use_log_fbank: " << use_log_fbank << "\n";
os << "use_power: " << use_power << "\n";
os << "device: " << device << "\n";
return os.str();
}
};

View File

@ -21,8 +21,8 @@ std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts) {
return os;
}
FeatureWindowFunction::FeatureWindowFunction(
const FrameExtractionOptions &opts) {
FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts,
torch::Device device) {
int32_t frame_length = opts.WindowSize();
KALDIFEAT_ASSERT(frame_length > 0);
@ -54,6 +54,9 @@ FeatureWindowFunction::FeatureWindowFunction(
}
window = window.unsqueeze(0);
if (window.device() != device) {
window = window.to(device);
}
}
void FeatureWindowFunction::Apply(torch::Tensor *wave) const {

View File

@ -79,7 +79,8 @@ std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts);
class FeatureWindowFunction {
public:
FeatureWindowFunction() = default;
explicit FeatureWindowFunction(const FrameExtractionOptions &opts);
FeatureWindowFunction(const FrameExtractionOptions &opts,
torch::Device device);
void Apply(torch::Tensor *wave) const;
private:

View File

@ -88,7 +88,7 @@ float MelBanks::VtlnWarpMelFreq(
MelBanks::MelBanks(const MelBanksOptions &opts,
const FrameExtractionOptions &frame_opts,
float vtln_warp_factor)
float vtln_warp_factor, torch::Device device)
: htk_mode_(opts.htk_mode) {
int32_t num_bins = opts.num_bins;
if (num_bins < 3) KALDIFEAT_ERR << "Must have at least 3 mel bins";
@ -182,6 +182,10 @@ MelBanks::MelBanks(const MelBanksOptions &opts,
if (debug_) KALDIFEAT_LOG << bins_mat_;
bins_mat_.t_();
if (bins_mat_.device() != device) {
bins_mat_ = bins_mat_.to(device);
}
}
torch::Tensor MelBanks::Compute(const torch::Tensor &spectrum) const {

View File

@ -71,7 +71,8 @@ class MelBanks {
float vtln_warp_factor, float mel_freq);
MelBanks(const MelBanksOptions &opts,
const FrameExtractionOptions &frame_opts, float vtln_warp_factor);
const FrameExtractionOptions &frame_opts, float vtln_warp_factor,
torch::Device device);
int32_t NumBins() const { return static_cast<int32_t>(bins_mat_.sizes()[0]); }

View File

@ -19,6 +19,22 @@ void PybindFbankOptions(py::module &m) {
.def_readwrite("htk_compat", &FbankOptions::htk_compat)
.def_readwrite("use_log_fbank", &FbankOptions::use_log_fbank)
.def_readwrite("use_power", &FbankOptions::use_power)
// Python-facing setter: copies the type and index of a Python device object
// (any object exposing `.type` and `.index` attributes, e.g. torch.device)
// into FbankOptions::device.
.def("set_device",
[](FbankOptions *fbank_opts, py::object device) {
// `device.type` is read as a string; only "cpu" and "cuda" are accepted.
std::string device_type =
static_cast<py::str>(device.attr("type"));
KALDIFEAT_ASSERT(device_type == "cpu" || device_type == "cuda")
<< "Unsupported device type: " << device_type;
// `device.index` may be None; in that case device_index keeps its
// default value of 0.
auto index_attr = static_cast<py::object>(device.attr("index"));
int32_t device_index = 0;
if (!index_attr.is_none())
device_index = static_cast<py::int_>(index_attr);
// The index is only forwarded for CUDA; the CPU branch does not use it.
if (device_type == "cpu")
fbank_opts->device = torch::Device("cpu");
else
fbank_opts->device = torch::Device(torch::kCUDA, device_index);
})
.def("__str__", [](const FbankOptions &self) -> std::string {
return self.ToString();
});

View File

@ -56,27 +56,40 @@ def read_wave() -> torch.Tensor:
def test_and_benchmark_default_parameters():
devices = [torch.device('cpu')]
if torch.cuda.is_available():
devices.append(torch.device('cuda', 0))
for device in devices:
fbank_opts = _kaldifeat.FbankOptions()
fbank_opts.frame_opts.dither = 0
fbank_opts.set_device(device)
fbank = _kaldifeat.Fbank(fbank_opts)
data = read_wave()
data = read_wave().to(device)
ans, elapsed_seconds = _kaldifeat._compute_with_elapsed_time(data, fbank)
ans, elapsed_seconds = _kaldifeat._compute_with_elapsed_time(
data, fbank)
expected = read_ark_txt()
assert torch.allclose(ans, expected, rtol=1e-3)
print('elapsed seconds:', elapsed_seconds)
assert torch.allclose(ans.cpu(), expected, rtol=1e-3)
print(f'elapsed seconds {device}:', elapsed_seconds)
def test_use_energy_htk_compat_true():
devices = [torch.device('cpu')]
if torch.cuda.is_available():
devices.append(torch.device('cuda', 0))
for device in devices:
fbank_opts = _kaldifeat.FbankOptions()
fbank_opts.frame_opts.dither = 0
fbank_opts.set_device(device)
fbank_opts.use_energy = True
fbank_opts.htk_compat = True
fbank = _kaldifeat.Fbank(fbank_opts)
data = read_wave()
data = read_wave().to(device)
ans = _kaldifeat.compute(data, fbank)
@ -88,17 +101,23 @@ def test_use_energy_htk_compat_true():
15.57543 21.93211 25.55334 24.08282 15.93052 12.47129 10.4782 9.028108 7.90429 6.946663 6.310408 5.903729 5.777827 6.027511 6.000434 6.190129 5.968217 6.455313 7.450428 7.993948 8.512851 8.341401 8.14073 25.38995
'''
expected = parse_str(expected_str)
assert torch.allclose(ans[:3, :], expected, rtol=1e-3)
assert torch.allclose(ans[:3, :].cpu(), expected, rtol=1e-3)
def test_use_energy_htk_compat_false():
devices = [torch.device('cpu')]
if torch.cuda.is_available():
devices.append(torch.device('cuda', 0))
for device in devices:
fbank_opts = _kaldifeat.FbankOptions()
fbank_opts.frame_opts.dither = 0
fbank_opts.use_energy = True
fbank_opts.htk_compat = False
fbank_opts.set_device(device)
fbank = _kaldifeat.Fbank(fbank_opts)
data = read_wave()
data = read_wave().to(device)
ans = _kaldifeat.compute(data, fbank)
@ -110,7 +129,7 @@ def test_use_energy_htk_compat_false():
25.38995 15.57543 21.93211 25.55334 24.08282 15.93052 12.47129 10.4782 9.028108 7.90429 6.946663 6.310408 5.903729 5.777827 6.027511 6.000434 6.190129 5.968217 6.455313 7.450428 7.993948 8.512851 8.341401 8.14073
'''
expected = parse_str(expected_str)
assert torch.allclose(ans[:3, :], expected, rtol=1e-3)
assert torch.allclose(ans[:3, :].cpu(), expected, rtol=1e-3)
def main():