mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-09 18:12:17 +00:00
support torch.device
This commit is contained in:
parent
b2980cdffd
commit
f909f839ab
@ -20,7 +20,7 @@ class OfflineFeatureTpl {
|
||||
// using the options class, that we cache at this level.
|
||||
OfflineFeatureTpl(const Options &opts)
|
||||
: computer_(opts),
|
||||
feature_window_function_(computer_.GetFrameOptions()) {}
|
||||
feature_window_function_(computer_.GetFrameOptions(), opts.device) {}
|
||||
|
||||
/**
|
||||
Computes the features for one file (one sequence of features).
|
||||
|
@ -37,7 +37,8 @@ const MelBanks *FbankComputer::GetMelBanks(float vtln_warp) {
|
||||
// std::map<float, MelBanks *>::iterator iter = mel_banks_.find(vtln_warp);
|
||||
auto iter = mel_banks_.find(vtln_warp);
|
||||
if (iter == mel_banks_.end()) {
|
||||
this_mel_banks = new MelBanks(opts_.mel_opts, opts_.frame_opts, vtln_warp);
|
||||
this_mel_banks =
|
||||
new MelBanks(opts_.mel_opts, opts_.frame_opts, vtln_warp, opts_.device);
|
||||
mel_banks_[vtln_warp] = this_mel_banks;
|
||||
} else {
|
||||
this_mel_banks = iter->second;
|
||||
|
@ -12,6 +12,7 @@
|
||||
#include "kaldifeat/csrc/feature-common.h"
|
||||
#include "kaldifeat/csrc/feature-window.h"
|
||||
#include "kaldifeat/csrc/mel-computations.h"
|
||||
#include "torch/torch.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
@ -37,7 +38,9 @@ struct FbankOptions {
|
||||
// analysis, else magnitude.
|
||||
bool use_power = true;
|
||||
|
||||
FbankOptions() { mel_opts.num_bins = 23; }
|
||||
torch::Device device;
|
||||
|
||||
FbankOptions() : device("cpu") { mel_opts.num_bins = 23; }
|
||||
|
||||
std::string ToString() const {
|
||||
std::ostringstream os;
|
||||
@ -54,6 +57,7 @@ struct FbankOptions {
|
||||
os << "htk_compat: " << htk_compat << "\n";
|
||||
os << "use_log_fbank: " << use_log_fbank << "\n";
|
||||
os << "use_power: " << use_power << "\n";
|
||||
os << "device: " << device << "\n";
|
||||
return os.str();
|
||||
}
|
||||
};
|
||||
|
@ -21,8 +21,8 @@ std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts) {
|
||||
return os;
|
||||
}
|
||||
|
||||
FeatureWindowFunction::FeatureWindowFunction(
|
||||
const FrameExtractionOptions &opts) {
|
||||
FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts,
|
||||
torch::Device device) {
|
||||
int32_t frame_length = opts.WindowSize();
|
||||
KALDIFEAT_ASSERT(frame_length > 0);
|
||||
|
||||
@ -54,6 +54,9 @@ FeatureWindowFunction::FeatureWindowFunction(
|
||||
}
|
||||
|
||||
window = window.unsqueeze(0);
|
||||
if (window.device() != device) {
|
||||
window = window.to(device);
|
||||
}
|
||||
}
|
||||
|
||||
void FeatureWindowFunction::Apply(torch::Tensor *wave) const {
|
||||
|
@ -79,7 +79,8 @@ std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts);
|
||||
class FeatureWindowFunction {
|
||||
public:
|
||||
FeatureWindowFunction() = default;
|
||||
explicit FeatureWindowFunction(const FrameExtractionOptions &opts);
|
||||
FeatureWindowFunction(const FrameExtractionOptions &opts,
|
||||
torch::Device device);
|
||||
void Apply(torch::Tensor *wave) const;
|
||||
|
||||
private:
|
||||
|
@ -88,7 +88,7 @@ float MelBanks::VtlnWarpMelFreq(
|
||||
|
||||
MelBanks::MelBanks(const MelBanksOptions &opts,
|
||||
const FrameExtractionOptions &frame_opts,
|
||||
float vtln_warp_factor)
|
||||
float vtln_warp_factor, torch::Device device)
|
||||
: htk_mode_(opts.htk_mode) {
|
||||
int32_t num_bins = opts.num_bins;
|
||||
if (num_bins < 3) KALDIFEAT_ERR << "Must have at least 3 mel bins";
|
||||
@ -182,6 +182,10 @@ MelBanks::MelBanks(const MelBanksOptions &opts,
|
||||
if (debug_) KALDIFEAT_LOG << bins_mat_;
|
||||
|
||||
bins_mat_.t_();
|
||||
|
||||
if (bins_mat_.device() != device) {
|
||||
bins_mat_ = bins_mat_.to(device);
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor MelBanks::Compute(const torch::Tensor &spectrum) const {
|
||||
|
@ -71,7 +71,8 @@ class MelBanks {
|
||||
float vtln_warp_factor, float mel_freq);
|
||||
|
||||
MelBanks(const MelBanksOptions &opts,
|
||||
const FrameExtractionOptions &frame_opts, float vtln_warp_factor);
|
||||
const FrameExtractionOptions &frame_opts, float vtln_warp_factor,
|
||||
torch::Device device);
|
||||
|
||||
int32_t NumBins() const { return static_cast<int32_t>(bins_mat_.sizes()[0]); }
|
||||
|
||||
|
@ -19,6 +19,22 @@ void PybindFbankOptions(py::module &m) {
|
||||
.def_readwrite("htk_compat", &FbankOptions::htk_compat)
|
||||
.def_readwrite("use_log_fbank", &FbankOptions::use_log_fbank)
|
||||
.def_readwrite("use_power", &FbankOptions::use_power)
|
||||
.def("set_device",
|
||||
[](FbankOptions *fbank_opts, py::object device) {
|
||||
std::string device_type =
|
||||
static_cast<py::str>(device.attr("type"));
|
||||
KALDIFEAT_ASSERT(device_type == "cpu" || device_type == "cuda")
|
||||
<< "Unsupported device type: " << device_type;
|
||||
|
||||
auto index_attr = static_cast<py::object>(device.attr("index"));
|
||||
int32_t device_index = 0;
|
||||
if (!index_attr.is_none())
|
||||
device_index = static_cast<py::int_>(index_attr);
|
||||
if (device_type == "cpu")
|
||||
fbank_opts->device = torch::Device("cpu");
|
||||
else
|
||||
fbank_opts->device = torch::Device(torch::kCUDA, device_index);
|
||||
})
|
||||
.def("__str__", [](const FbankOptions &self) -> std::string {
|
||||
return self.ToString();
|
||||
});
|
||||
|
@ -56,61 +56,80 @@ def read_wave() -> torch.Tensor:
|
||||
|
||||
|
||||
def test_and_benchmark_default_parameters():
|
||||
fbank_opts = _kaldifeat.FbankOptions()
|
||||
fbank_opts.frame_opts.dither = 0
|
||||
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||
devices = [torch.device('cpu')]
|
||||
if torch.cuda.is_available():
|
||||
devices.append(torch.device('cuda', 0))
|
||||
|
||||
data = read_wave()
|
||||
for device in devices:
|
||||
fbank_opts = _kaldifeat.FbankOptions()
|
||||
fbank_opts.frame_opts.dither = 0
|
||||
fbank_opts.set_device(device)
|
||||
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||
|
||||
ans, elapsed_seconds = _kaldifeat._compute_with_elapsed_time(data, fbank)
|
||||
data = read_wave().to(device)
|
||||
|
||||
expected = read_ark_txt()
|
||||
assert torch.allclose(ans, expected, rtol=1e-3)
|
||||
print('elapsed seconds:', elapsed_seconds)
|
||||
ans, elapsed_seconds = _kaldifeat._compute_with_elapsed_time(
|
||||
data, fbank)
|
||||
|
||||
expected = read_ark_txt()
|
||||
assert torch.allclose(ans.cpu(), expected, rtol=1e-3)
|
||||
print(f'elapsed seconds {device}:', elapsed_seconds)
|
||||
|
||||
|
||||
def test_use_energy_htk_compat_true():
|
||||
fbank_opts = _kaldifeat.FbankOptions()
|
||||
fbank_opts.frame_opts.dither = 0
|
||||
fbank_opts.use_energy = True
|
||||
fbank_opts.htk_compat = True
|
||||
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||
devices = [torch.device('cpu')]
|
||||
if torch.cuda.is_available():
|
||||
devices.append(torch.device('cuda', 0))
|
||||
|
||||
data = read_wave()
|
||||
for device in devices:
|
||||
fbank_opts = _kaldifeat.FbankOptions()
|
||||
fbank_opts.frame_opts.dither = 0
|
||||
fbank_opts.set_device(device)
|
||||
fbank_opts.use_energy = True
|
||||
fbank_opts.htk_compat = True
|
||||
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||
|
||||
ans = _kaldifeat.compute(data, fbank)
|
||||
data = read_wave().to(device)
|
||||
|
||||
# ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=1 scp:abc.scp ark,t:abc.txt
|
||||
# the first 3 rows are:
|
||||
expected_str = '''
|
||||
15.576 21.93211 25.55334 24.08283 15.93041 12.47176 10.47909 9.024426 7.899537 6.935482 6.21563 6.035741 6.140291 5.94696 6.146772 6.860236 6.702379 7.087324 6.929666 7.66336 7.935287 8.405977 8.309303 25.38995
|
||||
15.5755 21.93212 25.55334 24.08282 15.93044 12.47107 10.47753 9.026523 7.901362 6.939464 6.189109 5.926141 5.678882 5.553694 6.006057 6.066478 6.500169 7.277717 7.248817 7.699819 7.990362 8.033764 8.220113 25.38996
|
||||
15.57543 21.93211 25.55334 24.08282 15.93052 12.47129 10.4782 9.028108 7.90429 6.946663 6.310408 5.903729 5.777827 6.027511 6.000434 6.190129 5.968217 6.455313 7.450428 7.993948 8.512851 8.341401 8.14073 25.38995
|
||||
'''
|
||||
expected = parse_str(expected_str)
|
||||
assert torch.allclose(ans[:3, :], expected, rtol=1e-3)
|
||||
ans = _kaldifeat.compute(data, fbank)
|
||||
|
||||
# ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=1 scp:abc.scp ark,t:abc.txt
|
||||
# the first 3 rows are:
|
||||
expected_str = '''
|
||||
15.576 21.93211 25.55334 24.08283 15.93041 12.47176 10.47909 9.024426 7.899537 6.935482 6.21563 6.035741 6.140291 5.94696 6.146772 6.860236 6.702379 7.087324 6.929666 7.66336 7.935287 8.405977 8.309303 25.38995
|
||||
15.5755 21.93212 25.55334 24.08282 15.93044 12.47107 10.47753 9.026523 7.901362 6.939464 6.189109 5.926141 5.678882 5.553694 6.006057 6.066478 6.500169 7.277717 7.248817 7.699819 7.990362 8.033764 8.220113 25.38996
|
||||
15.57543 21.93211 25.55334 24.08282 15.93052 12.47129 10.4782 9.028108 7.90429 6.946663 6.310408 5.903729 5.777827 6.027511 6.000434 6.190129 5.968217 6.455313 7.450428 7.993948 8.512851 8.341401 8.14073 25.38995
|
||||
'''
|
||||
expected = parse_str(expected_str)
|
||||
assert torch.allclose(ans[:3, :].cpu(), expected, rtol=1e-3)
|
||||
|
||||
|
||||
def test_use_energy_htk_compat_false():
|
||||
fbank_opts = _kaldifeat.FbankOptions()
|
||||
fbank_opts.frame_opts.dither = 0
|
||||
fbank_opts.use_energy = True
|
||||
fbank_opts.htk_compat = False
|
||||
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||
devices = [torch.device('cpu')]
|
||||
if torch.cuda.is_available():
|
||||
devices.append(torch.device('cuda', 0))
|
||||
|
||||
data = read_wave()
|
||||
for device in devices:
|
||||
fbank_opts = _kaldifeat.FbankOptions()
|
||||
fbank_opts.frame_opts.dither = 0
|
||||
fbank_opts.use_energy = True
|
||||
fbank_opts.htk_compat = False
|
||||
fbank_opts.set_device(device)
|
||||
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||
|
||||
ans = _kaldifeat.compute(data, fbank)
|
||||
data = read_wave().to(device)
|
||||
|
||||
# ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=0 scp:abc.scp ark,t:abc.txt
|
||||
# the first 3 rows are:
|
||||
expected_str = '''
|
||||
25.38995 15.576 21.93211 25.55334 24.08283 15.93041 12.47176 10.47909 9.024426 7.899537 6.935482 6.21563 6.035741 6.140291 5.94696 6.146772 6.860236 6.702379 7.087324 6.929666 7.66336 7.935287 8.405977 8.309303
|
||||
25.38996 15.5755 21.93212 25.55334 24.08282 15.93044 12.47107 10.47753 9.026523 7.901362 6.939464 6.189109 5.926141 5.678882 5.553694 6.006057 6.066478 6.500169 7.277717 7.248817 7.699819 7.990362 8.033764 8.220113
|
||||
25.38995 15.57543 21.93211 25.55334 24.08282 15.93052 12.47129 10.4782 9.028108 7.90429 6.946663 6.310408 5.903729 5.777827 6.027511 6.000434 6.190129 5.968217 6.455313 7.450428 7.993948 8.512851 8.341401 8.14073
|
||||
'''
|
||||
expected = parse_str(expected_str)
|
||||
assert torch.allclose(ans[:3, :], expected, rtol=1e-3)
|
||||
ans = _kaldifeat.compute(data, fbank)
|
||||
|
||||
# ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=0 scp:abc.scp ark,t:abc.txt
|
||||
# the first 3 rows are:
|
||||
expected_str = '''
|
||||
25.38995 15.576 21.93211 25.55334 24.08283 15.93041 12.47176 10.47909 9.024426 7.899537 6.935482 6.21563 6.035741 6.140291 5.94696 6.146772 6.860236 6.702379 7.087324 6.929666 7.66336 7.935287 8.405977 8.309303
|
||||
25.38996 15.5755 21.93212 25.55334 24.08282 15.93044 12.47107 10.47753 9.026523 7.901362 6.939464 6.189109 5.926141 5.678882 5.553694 6.006057 6.066478 6.500169 7.277717 7.248817 7.699819 7.990362 8.033764 8.220113
|
||||
25.38995 15.57543 21.93211 25.55334 24.08282 15.93052 12.47129 10.4782 9.028108 7.90429 6.946663 6.310408 5.903729 5.777827 6.027511 6.000434 6.190129 5.968217 6.455313 7.450428 7.993948 8.512851 8.341401 8.14073
|
||||
'''
|
||||
expected = parse_str(expected_str)
|
||||
assert torch.allclose(ans[:3, :].cpu(), expected, rtol=1e-3)
|
||||
|
||||
|
||||
def main():
|
||||
|
Loading…
x
Reference in New Issue
Block a user