support torch.device

This commit is contained in:
Fangjun Kuang 2021-02-27 23:40:47 +08:00
parent b2980cdffd
commit f909f839ab
9 changed files with 97 additions and 48 deletions

View File

@ -20,7 +20,7 @@ class OfflineFeatureTpl {
// using the options class, that we cache at this level.
OfflineFeatureTpl(const Options &opts)
: computer_(opts),
feature_window_function_(computer_.GetFrameOptions()) {}
feature_window_function_(computer_.GetFrameOptions(), opts.device) {}
/**
Computes the features for one file (one sequence of features).

View File

@ -37,7 +37,8 @@ const MelBanks *FbankComputer::GetMelBanks(float vtln_warp) {
// std::map<float, MelBanks *>::iterator iter = mel_banks_.find(vtln_warp);
auto iter = mel_banks_.find(vtln_warp);
if (iter == mel_banks_.end()) {
this_mel_banks = new MelBanks(opts_.mel_opts, opts_.frame_opts, vtln_warp);
this_mel_banks =
new MelBanks(opts_.mel_opts, opts_.frame_opts, vtln_warp, opts_.device);
mel_banks_[vtln_warp] = this_mel_banks;
} else {
this_mel_banks = iter->second;

View File

@ -12,6 +12,7 @@
#include "kaldifeat/csrc/feature-common.h"
#include "kaldifeat/csrc/feature-window.h"
#include "kaldifeat/csrc/mel-computations.h"
#include "torch/torch.h"
namespace kaldifeat {
@ -37,7 +38,9 @@ struct FbankOptions {
// analysis, else magnitude.
bool use_power = true;
FbankOptions() { mel_opts.num_bins = 23; }
torch::Device device;
FbankOptions() : device("cpu") { mel_opts.num_bins = 23; }
std::string ToString() const {
std::ostringstream os;
@ -54,6 +57,7 @@ struct FbankOptions {
os << "htk_compat: " << htk_compat << "\n";
os << "use_log_fbank: " << use_log_fbank << "\n";
os << "use_power: " << use_power << "\n";
os << "device: " << device << "\n";
return os.str();
}
};

View File

@ -21,8 +21,8 @@ std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts) {
return os;
}
FeatureWindowFunction::FeatureWindowFunction(
const FrameExtractionOptions &opts) {
FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts,
torch::Device device) {
int32_t frame_length = opts.WindowSize();
KALDIFEAT_ASSERT(frame_length > 0);
@ -54,6 +54,9 @@ FeatureWindowFunction::FeatureWindowFunction(
}
window = window.unsqueeze(0);
if (window.device() != device) {
window = window.to(device);
}
}
void FeatureWindowFunction::Apply(torch::Tensor *wave) const {

View File

@ -79,7 +79,8 @@ std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts);
class FeatureWindowFunction {
public:
FeatureWindowFunction() = default;
explicit FeatureWindowFunction(const FrameExtractionOptions &opts);
FeatureWindowFunction(const FrameExtractionOptions &opts,
torch::Device device);
void Apply(torch::Tensor *wave) const;
private:

View File

@ -88,7 +88,7 @@ float MelBanks::VtlnWarpMelFreq(
MelBanks::MelBanks(const MelBanksOptions &opts,
const FrameExtractionOptions &frame_opts,
float vtln_warp_factor)
float vtln_warp_factor, torch::Device device)
: htk_mode_(opts.htk_mode) {
int32_t num_bins = opts.num_bins;
if (num_bins < 3) KALDIFEAT_ERR << "Must have at least 3 mel bins";
@ -182,6 +182,10 @@ MelBanks::MelBanks(const MelBanksOptions &opts,
if (debug_) KALDIFEAT_LOG << bins_mat_;
bins_mat_.t_();
if (bins_mat_.device() != device) {
bins_mat_ = bins_mat_.to(device);
}
}
torch::Tensor MelBanks::Compute(const torch::Tensor &spectrum) const {

View File

@ -71,7 +71,8 @@ class MelBanks {
float vtln_warp_factor, float mel_freq);
MelBanks(const MelBanksOptions &opts,
const FrameExtractionOptions &frame_opts, float vtln_warp_factor);
const FrameExtractionOptions &frame_opts, float vtln_warp_factor,
torch::Device device);
int32_t NumBins() const { return static_cast<int32_t>(bins_mat_.sizes()[0]); }

View File

@ -19,6 +19,22 @@ void PybindFbankOptions(py::module &m) {
.def_readwrite("htk_compat", &FbankOptions::htk_compat)
.def_readwrite("use_log_fbank", &FbankOptions::use_log_fbank)
.def_readwrite("use_power", &FbankOptions::use_power)
// Python-visible `set_device(device)`: stores the requested device in
// FbankOptions::device. `device` is taken as a generic Python object and
// only its `type` and `index` attributes are read (presumably a
// torch.device, as that is what exposes these attributes -- confirm).
.def("set_device",
[](FbankOptions *fbank_opts, py::object device) {
// Device kind as a string; anything other than "cpu" or "cuda" fails
// the assertion below.
std::string device_type =
static_cast<py::str>(device.attr("type"));
KALDIFEAT_ASSERT(device_type == "cpu" || device_type == "cuda")
<< "Unsupported device type: " << device_type;
// `index` may be None on the Python side; the index then stays 0.
auto index_attr = static_cast<py::object>(device.attr("index"));
int32_t device_index = 0;
if (!index_attr.is_none())
device_index = static_cast<py::int_>(index_attr);
// The index is only consumed on the CUDA branch; the CPU device is
// constructed without one.
if (device_type == "cpu")
fbank_opts->device = torch::Device("cpu");
else
fbank_opts->device = torch::Device(torch::kCUDA, device_index);
})
.def("__str__", [](const FbankOptions &self) -> std::string {
return self.ToString();
});

View File

@ -56,61 +56,80 @@ def read_wave() -> torch.Tensor:
def test_and_benchmark_default_parameters():
fbank_opts = _kaldifeat.FbankOptions()
fbank_opts.frame_opts.dither = 0
fbank = _kaldifeat.Fbank(fbank_opts)
devices = [torch.device('cpu')]
if torch.cuda.is_available():
devices.append(torch.device('cuda', 0))
data = read_wave()
for device in devices:
fbank_opts = _kaldifeat.FbankOptions()
fbank_opts.frame_opts.dither = 0
fbank_opts.set_device(device)
fbank = _kaldifeat.Fbank(fbank_opts)
ans, elapsed_seconds = _kaldifeat._compute_with_elapsed_time(data, fbank)
data = read_wave().to(device)
expected = read_ark_txt()
assert torch.allclose(ans, expected, rtol=1e-3)
print('elapsed seconds:', elapsed_seconds)
ans, elapsed_seconds = _kaldifeat._compute_with_elapsed_time(
data, fbank)
expected = read_ark_txt()
assert torch.allclose(ans.cpu(), expected, rtol=1e-3)
print(f'elapsed seconds {device}:', elapsed_seconds)
def test_use_energy_htk_compat_true():
fbank_opts = _kaldifeat.FbankOptions()
fbank_opts.frame_opts.dither = 0
fbank_opts.use_energy = True
fbank_opts.htk_compat = True
fbank = _kaldifeat.Fbank(fbank_opts)
devices = [torch.device('cpu')]
if torch.cuda.is_available():
devices.append(torch.device('cuda', 0))
data = read_wave()
for device in devices:
fbank_opts = _kaldifeat.FbankOptions()
fbank_opts.frame_opts.dither = 0
fbank_opts.set_device(device)
fbank_opts.use_energy = True
fbank_opts.htk_compat = True
fbank = _kaldifeat.Fbank(fbank_opts)
ans = _kaldifeat.compute(data, fbank)
data = read_wave().to(device)
# ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=1 scp:abc.scp ark,t:abc.txt
# the first 3 rows are:
expected_str = '''
15.576 21.93211 25.55334 24.08283 15.93041 12.47176 10.47909 9.024426 7.899537 6.935482 6.21563 6.035741 6.140291 5.94696 6.146772 6.860236 6.702379 7.087324 6.929666 7.66336 7.935287 8.405977 8.309303 25.38995
15.5755 21.93212 25.55334 24.08282 15.93044 12.47107 10.47753 9.026523 7.901362 6.939464 6.189109 5.926141 5.678882 5.553694 6.006057 6.066478 6.500169 7.277717 7.248817 7.699819 7.990362 8.033764 8.220113 25.38996
15.57543 21.93211 25.55334 24.08282 15.93052 12.47129 10.4782 9.028108 7.90429 6.946663 6.310408 5.903729 5.777827 6.027511 6.000434 6.190129 5.968217 6.455313 7.450428 7.993948 8.512851 8.341401 8.14073 25.38995
'''
expected = parse_str(expected_str)
assert torch.allclose(ans[:3, :], expected, rtol=1e-3)
ans = _kaldifeat.compute(data, fbank)
# ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=1 scp:abc.scp ark,t:abc.txt
# the first 3 rows are:
expected_str = '''
15.576 21.93211 25.55334 24.08283 15.93041 12.47176 10.47909 9.024426 7.899537 6.935482 6.21563 6.035741 6.140291 5.94696 6.146772 6.860236 6.702379 7.087324 6.929666 7.66336 7.935287 8.405977 8.309303 25.38995
15.5755 21.93212 25.55334 24.08282 15.93044 12.47107 10.47753 9.026523 7.901362 6.939464 6.189109 5.926141 5.678882 5.553694 6.006057 6.066478 6.500169 7.277717 7.248817 7.699819 7.990362 8.033764 8.220113 25.38996
15.57543 21.93211 25.55334 24.08282 15.93052 12.47129 10.4782 9.028108 7.90429 6.946663 6.310408 5.903729 5.777827 6.027511 6.000434 6.190129 5.968217 6.455313 7.450428 7.993948 8.512851 8.341401 8.14073 25.38995
'''
expected = parse_str(expected_str)
assert torch.allclose(ans[:3, :].cpu(), expected, rtol=1e-3)
def test_use_energy_htk_compat_false():
fbank_opts = _kaldifeat.FbankOptions()
fbank_opts.frame_opts.dither = 0
fbank_opts.use_energy = True
fbank_opts.htk_compat = False
fbank = _kaldifeat.Fbank(fbank_opts)
devices = [torch.device('cpu')]
if torch.cuda.is_available():
devices.append(torch.device('cuda', 0))
data = read_wave()
for device in devices:
fbank_opts = _kaldifeat.FbankOptions()
fbank_opts.frame_opts.dither = 0
fbank_opts.use_energy = True
fbank_opts.htk_compat = False
fbank_opts.set_device(device)
fbank = _kaldifeat.Fbank(fbank_opts)
ans = _kaldifeat.compute(data, fbank)
data = read_wave().to(device)
# ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=0 scp:abc.scp ark,t:abc.txt
# the first 3 rows are:
expected_str = '''
25.38995 15.576 21.93211 25.55334 24.08283 15.93041 12.47176 10.47909 9.024426 7.899537 6.935482 6.21563 6.035741 6.140291 5.94696 6.146772 6.860236 6.702379 7.087324 6.929666 7.66336 7.935287 8.405977 8.309303
25.38996 15.5755 21.93212 25.55334 24.08282 15.93044 12.47107 10.47753 9.026523 7.901362 6.939464 6.189109 5.926141 5.678882 5.553694 6.006057 6.066478 6.500169 7.277717 7.248817 7.699819 7.990362 8.033764 8.220113
25.38995 15.57543 21.93211 25.55334 24.08282 15.93052 12.47129 10.4782 9.028108 7.90429 6.946663 6.310408 5.903729 5.777827 6.027511 6.000434 6.190129 5.968217 6.455313 7.450428 7.993948 8.512851 8.341401 8.14073
'''
expected = parse_str(expected_str)
assert torch.allclose(ans[:3, :], expected, rtol=1e-3)
ans = _kaldifeat.compute(data, fbank)
# ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=0 scp:abc.scp ark,t:abc.txt
# the first 3 rows are:
expected_str = '''
25.38995 15.576 21.93211 25.55334 24.08283 15.93041 12.47176 10.47909 9.024426 7.899537 6.935482 6.21563 6.035741 6.140291 5.94696 6.146772 6.860236 6.702379 7.087324 6.929666 7.66336 7.935287 8.405977 8.309303
25.38996 15.5755 21.93212 25.55334 24.08282 15.93044 12.47107 10.47753 9.026523 7.901362 6.939464 6.189109 5.926141 5.678882 5.553694 6.006057 6.066478 6.500169 7.277717 7.248817 7.699819 7.990362 8.033764 8.220113
25.38995 15.57543 21.93211 25.55334 24.08282 15.93052 12.47129 10.4782 9.028108 7.90429 6.946663 6.310408 5.903729 5.777827 6.027511 6.000434 6.190129 5.968217 6.455313 7.450428 7.993948 8.512851 8.341401 8.14073
'''
expected = parse_str(expected_str)
assert torch.allclose(ans[:3, :].cpu(), expected, rtol=1e-3)
def main():