mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-09 01:52:39 +00:00
Add OnlineFbank python APIs.
This commit is contained in:
parent
039e27dd32
commit
e59d05a45a
@ -240,6 +240,7 @@ torch::Tensor ExtractWindow(int64_t sample_offset, const torch::Tensor &wave,
|
|||||||
|
|
||||||
p_window[s] = p_wave[s_in_wave];
|
p_window[s] = p_wave[s_in_wave];
|
||||||
}
|
}
|
||||||
|
return window;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -56,7 +56,8 @@ class OnlineFeatureInterface {
|
|||||||
/// it's more efficient to do things in a batch).
|
/// it's more efficient to do things in a batch).
|
||||||
///
|
///
|
||||||
/// The returned tensor has shape (frames.size(), Dim()).
|
/// The returned tensor has shape (frames.size(), Dim()).
|
||||||
virtual torch::Tensor GetFrames(const std::vector<int32_t> &frames) {
|
virtual std::vector<torch::Tensor> GetFrames(
|
||||||
|
const std::vector<int32_t> &frames) {
|
||||||
std::vector<torch::Tensor> features;
|
std::vector<torch::Tensor> features;
|
||||||
features.reserve(frames.size());
|
features.reserve(frames.size());
|
||||||
|
|
||||||
@ -64,8 +65,9 @@ class OnlineFeatureInterface {
|
|||||||
torch::Tensor f = GetFrame(i);
|
torch::Tensor f = GetFrame(i);
|
||||||
features.push_back(std::move(f));
|
features.push_back(std::move(f));
|
||||||
}
|
}
|
||||||
|
return features;
|
||||||
|
|
||||||
return torch::cat(features, /*dim*/ 0);
|
// return torch::cat(features, [>dim<] 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This would be called from the application, when you get more wave data.
|
/// This would be called from the application, when you get more wave data.
|
||||||
|
@ -48,12 +48,7 @@ OnlineGenericBaseFeature<C>::OnlineGenericBaseFeature(
|
|||||||
window_function_(opts.frame_opts, opts.device),
|
window_function_(opts.frame_opts, opts.device),
|
||||||
features_(opts.frame_opts.max_feature_vectors),
|
features_(opts.frame_opts.max_feature_vectors),
|
||||||
input_finished_(false),
|
input_finished_(false),
|
||||||
waveform_offset_(0) {
|
waveform_offset_(0) {}
|
||||||
// Casting to uint32_t, an unsigned type, means that -1 would be treated
|
|
||||||
// as `very large`.
|
|
||||||
KALDIFEAT_ASSERT(static_cast<uint32_t>(opts.frame_opts.max_feature_vectors) >
|
|
||||||
200);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class C>
|
template <class C>
|
||||||
void OnlineGenericBaseFeature<C>::AcceptWaveform(
|
void OnlineGenericBaseFeature<C>::AcceptWaveform(
|
||||||
@ -61,6 +56,7 @@ void OnlineGenericBaseFeature<C>::AcceptWaveform(
|
|||||||
if (original_waveform.numel() == 0) return; // Nothing to do.
|
if (original_waveform.numel() == 0) return; // Nothing to do.
|
||||||
|
|
||||||
KALDIFEAT_ASSERT(original_waveform.dim() == 1);
|
KALDIFEAT_ASSERT(original_waveform.dim() == 1);
|
||||||
|
KALDIFEAT_ASSERT(sampling_rate == computer_.GetFrameOptions().samp_freq);
|
||||||
|
|
||||||
if (input_finished_)
|
if (input_finished_)
|
||||||
KALDIFEAT_ERR << "AcceptWaveform called after InputFinished() was called.";
|
KALDIFEAT_ERR << "AcceptWaveform called after InputFinished() was called.";
|
||||||
|
@ -7,6 +7,7 @@ pybind11_add_module(_kaldifeat
|
|||||||
feature-window.cc
|
feature-window.cc
|
||||||
kaldifeat.cc
|
kaldifeat.cc
|
||||||
mel-computations.cc
|
mel-computations.cc
|
||||||
|
online-feature.cc
|
||||||
utils.cc
|
utils.cc
|
||||||
)
|
)
|
||||||
target_link_libraries(_kaldifeat PRIVATE kaldifeat_core)
|
target_link_libraries(_kaldifeat PRIVATE kaldifeat_core)
|
||||||
|
@ -11,6 +11,7 @@
|
|||||||
#include "kaldifeat/python/csrc/feature-spectrogram.h"
|
#include "kaldifeat/python/csrc/feature-spectrogram.h"
|
||||||
#include "kaldifeat/python/csrc/feature-window.h"
|
#include "kaldifeat/python/csrc/feature-window.h"
|
||||||
#include "kaldifeat/python/csrc/mel-computations.h"
|
#include "kaldifeat/python/csrc/mel-computations.h"
|
||||||
|
#include "kaldifeat/python/csrc/online-feature.h"
|
||||||
#include "torch/torch.h"
|
#include "torch/torch.h"
|
||||||
|
|
||||||
namespace kaldifeat {
|
namespace kaldifeat {
|
||||||
@ -24,6 +25,7 @@ PYBIND11_MODULE(_kaldifeat, m) {
|
|||||||
PybindFeatureMfcc(m);
|
PybindFeatureMfcc(m);
|
||||||
PybindFeaturePlp(m);
|
PybindFeaturePlp(m);
|
||||||
PybindFeatureSpectrogram(m);
|
PybindFeatureSpectrogram(m);
|
||||||
|
PybindOnlineFeature(m);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace kaldifeat
|
} // namespace kaldifeat
|
||||||
|
36
kaldifeat/python/csrc/online-feature.cc
Normal file
36
kaldifeat/python/csrc/online-feature.cc
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
// kaldifeat/python/csrc/online-feature.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
#include "kaldifeat/python/csrc/online-feature.h"
|
||||||
|
|
||||||
|
#include "kaldifeat/csrc/online-feature.h"
|
||||||
|
|
||||||
|
namespace kaldifeat {
|
||||||
|
|
||||||
|
template <typename C>
|
||||||
|
void PybindOnlineFeatureTpl(py::module &m, const std::string &class_name,
|
||||||
|
const std::string &class_help_doc = "") {
|
||||||
|
using PyClass = OnlineGenericBaseFeature<C>;
|
||||||
|
using Options = typename C::Options;
|
||||||
|
py::class_<PyClass>(m, class_name.c_str(), class_help_doc.c_str())
|
||||||
|
.def(py::init<const Options &>(), py::arg("opts"))
|
||||||
|
.def_property_readonly("dim", &PyClass::Dim)
|
||||||
|
.def_property_readonly("frame_shift_in_seconds",
|
||||||
|
&PyClass::FrameShiftInSeconds)
|
||||||
|
.def_property_readonly("num_frames_ready", &PyClass::NumFramesReady)
|
||||||
|
.def("is_last_frame", &PyClass::IsLastFrame, py::arg("frame"))
|
||||||
|
.def("get_frame", &PyClass::GetFrame, py::arg("frame"))
|
||||||
|
.def("get_frames", &PyClass::GetFrames, py::arg("frames"))
|
||||||
|
.def("accept_waveform", &PyClass::AcceptWaveform,
|
||||||
|
py::arg("sampling_rate"), py::arg("waveform"))
|
||||||
|
.def("input_finished", &PyClass::InputFinished);
|
||||||
|
}
|
||||||
|
|
||||||
|
void PybindOnlineFeature(py::module &m) {
|
||||||
|
PybindOnlineFeatureTpl<Mfcc>(m, "OnlineMfcc");
|
||||||
|
PybindOnlineFeatureTpl<Fbank>(m, "OnlineFbank");
|
||||||
|
PybindOnlineFeatureTpl<Plp>(m, "OnlinePlp");
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace kaldifeat
|
16
kaldifeat/python/csrc/online-feature.h
Normal file
16
kaldifeat/python/csrc/online-feature.h
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
// kaldifeat/python/csrc/online-feature.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
#ifndef KALDIFEAT_PYTHON_CSRC_ONLINE_FEATURE_H_
|
||||||
|
#define KALDIFEAT_PYTHON_CSRC_ONLINE_FEATURE_H_
|
||||||
|
|
||||||
|
#include "kaldifeat/python/csrc/kaldifeat.h"
|
||||||
|
|
||||||
|
namespace kaldifeat {
|
||||||
|
|
||||||
|
void PybindOnlineFeature(py::module &m);
|
||||||
|
|
||||||
|
} // namespace kaldifeat
|
||||||
|
|
||||||
|
#endif // KALDIFEAT_PYTHON_CSRC_ONLINE_FEATURE_H_
|
@ -8,7 +8,7 @@ from _kaldifeat import (
|
|||||||
SpectrogramOptions,
|
SpectrogramOptions,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .fbank import Fbank
|
from .fbank import Fbank, OnlineFbank
|
||||||
from .mfcc import Mfcc
|
from .mfcc import Mfcc
|
||||||
from .plp import Plp
|
from .plp import Plp
|
||||||
from .spectrogram import Spectrogram
|
from .spectrogram import Spectrogram
|
||||||
|
@ -4,9 +4,20 @@
|
|||||||
import _kaldifeat
|
import _kaldifeat
|
||||||
|
|
||||||
from .offline_feature import OfflineFeature
|
from .offline_feature import OfflineFeature
|
||||||
|
from .online_feature import OnlineFeature
|
||||||
|
|
||||||
|
|
||||||
class Fbank(OfflineFeature):
|
class Fbank(OfflineFeature):
|
||||||
def __init__(self, opts: _kaldifeat.FbankOptions):
|
def __init__(self, opts: _kaldifeat.FbankOptions):
|
||||||
super().__init__(opts)
|
super().__init__(opts)
|
||||||
self.computer = _kaldifeat.Fbank(opts)
|
self.computer = _kaldifeat.Fbank(opts)
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineFbank(OnlineFeature):
|
||||||
|
def __init__(self, opts: _kaldifeat.FbankOptions):
|
||||||
|
super().__init__(opts)
|
||||||
|
self.computer = _kaldifeat.OnlineFbank(opts)
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
self.opts = _kaldifeat.FbankOptions.from_dict(state)
|
||||||
|
self.computer = _kaldifeat.Fbank(self.opts)
|
||||||
|
95
kaldifeat/python/kaldifeat/online_feature.py
Normal file
95
kaldifeat/python/kaldifeat/online_feature.py
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
# Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineFeature(object):
|
||||||
|
"""Offline feature is a base class of other feature computers,
|
||||||
|
e.g., Fbank, Mfcc.
|
||||||
|
|
||||||
|
This class has two fields:
|
||||||
|
|
||||||
|
(1) opts. It contains the options for the feature computer.
|
||||||
|
(2) computer. The actual feature computer. It should be
|
||||||
|
instantiated by subclasses.
|
||||||
|
|
||||||
|
Caution:
|
||||||
|
It supports only CPU at present.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, opts):
|
||||||
|
assert opts.device.type == "cpu"
|
||||||
|
|
||||||
|
self.opts = opts
|
||||||
|
|
||||||
|
# self.computer is expected to be set by subclasses
|
||||||
|
self.computer = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_frames_ready(self) -> int:
|
||||||
|
"""Return the number of ready frames.
|
||||||
|
|
||||||
|
It can be updated by :method:`accept_waveform`.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
If you set ``opts.frame_opts.max_feature_vectors``, then
|
||||||
|
the valid frame indexes are in the range.
|
||||||
|
``[num_frames_ready - max_feature_vectors, num_frames_ready)``
|
||||||
|
|
||||||
|
If you leave ``opts.frame_opts.max_feature_vectors`` to its default
|
||||||
|
value, then the range is ``[0, num_frames_ready)``
|
||||||
|
"""
|
||||||
|
return self.computer.num_frames_ready
|
||||||
|
|
||||||
|
def is_last_frame(self, frame: int) -> bool:
|
||||||
|
"""Return True if the given frame is the last frame."""
|
||||||
|
return self.computer.is_last_frame(frame)
|
||||||
|
|
||||||
|
def get_frame(self, frame: int) -> torch.Tensor:
|
||||||
|
"""Get the frame by its index.
|
||||||
|
Args:
|
||||||
|
frame:
|
||||||
|
The frame index. If ``opts.frame_opts.max_feature_vectors`` is
|
||||||
|
-1, then its valid values are in the range
|
||||||
|
``[0, num_frames_ready)``. Otherwise, the range is
|
||||||
|
``[num_frames_ready - max_feature_vectors, num_frames_ready)``.
|
||||||
|
Returns:
|
||||||
|
Return a 2-D tensor with shape ``(1, feature_dim)``
|
||||||
|
"""
|
||||||
|
return self.computer.get_frame(frame)
|
||||||
|
|
||||||
|
def get_frames(self, frames: List[int]) -> List[torch.Tensor]:
|
||||||
|
"""Get frames at the given frame indexes.
|
||||||
|
Args:
|
||||||
|
frames:
|
||||||
|
Frames whose indexes are in this list are returned.
|
||||||
|
Returns:
|
||||||
|
Return a list of feature frames at the given indexes.
|
||||||
|
"""
|
||||||
|
return self.computer.get_frames(frames)
|
||||||
|
|
||||||
|
def accept_waveform(
|
||||||
|
self, sampling_rate: float, waveform: torch.Tensor
|
||||||
|
) -> None:
|
||||||
|
"""Send audio samples to the extractor.
|
||||||
|
Args:
|
||||||
|
sampling_rate:
|
||||||
|
The sampling rate of the given audio samples. It has to be equal
|
||||||
|
to ``opts.frame_opts.samp_freq``.
|
||||||
|
waveform:
|
||||||
|
A 1-D tensor of shape (num_samples,). Its dtype is torch.float32
|
||||||
|
and has to be on CPU.
|
||||||
|
"""
|
||||||
|
self.computer.accept_waveform(sampling_rate, waveform)
|
||||||
|
|
||||||
|
def input_finished(self) -> None:
|
||||||
|
"""Tell the extractor that no more audio samples will be available.
|
||||||
|
After calling this function, you cannot invoke ``accept_waveform``
|
||||||
|
again.
|
||||||
|
"""
|
||||||
|
self.computer.input_finished()
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
return self.opts.as_dict()
|
@ -1,6 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
# Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
# Copyright 2021-2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
import pickle
|
import pickle
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -13,30 +13,88 @@ import kaldifeat
|
|||||||
cur_dir = Path(__file__).resolve().parent
|
cur_dir = Path(__file__).resolve().parent
|
||||||
|
|
||||||
|
|
||||||
|
def test_online_fbank(
|
||||||
|
opts: kaldifeat.FbankOptions,
|
||||||
|
wave: torch.Tensor,
|
||||||
|
cpu_features: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
opts:
|
||||||
|
The options to create the online fbank extractor.
|
||||||
|
wave:
|
||||||
|
The input 1-D waveform.
|
||||||
|
cpu_features:
|
||||||
|
The groud truth features that are computed offline
|
||||||
|
"""
|
||||||
|
online_fbank = kaldifeat.OnlineFbank(opts)
|
||||||
|
|
||||||
|
num_processed_frames = 0
|
||||||
|
i = 0 # current sample index to feed
|
||||||
|
while not online_fbank.is_last_frame(num_processed_frames - 1):
|
||||||
|
while num_processed_frames < online_fbank.num_frames_ready:
|
||||||
|
# There are new frames to be processed
|
||||||
|
frame = online_fbank.get_frame(num_processed_frames)
|
||||||
|
assert torch.allclose(
|
||||||
|
frame.squeeze(0), cpu_features[num_processed_frames]
|
||||||
|
)
|
||||||
|
num_processed_frames += 1
|
||||||
|
|
||||||
|
# Simulate streaming . Send a random number of audio samples
|
||||||
|
# to the extractor
|
||||||
|
num_samples = torch.randint(300, 1000, (1,)).item()
|
||||||
|
|
||||||
|
samples = wave[i : (i + num_samples)] # noqa
|
||||||
|
i += num_samples
|
||||||
|
if len(samples) == 0:
|
||||||
|
online_fbank.input_finished()
|
||||||
|
continue
|
||||||
|
|
||||||
|
online_fbank.accept_waveform(16000, samples)
|
||||||
|
|
||||||
|
assert num_processed_frames == online_fbank.num_frames_ready
|
||||||
|
assert num_processed_frames == cpu_features.size(0)
|
||||||
|
|
||||||
|
|
||||||
def test_fbank_default():
|
def test_fbank_default():
|
||||||
print("=====test_fbank_default=====")
|
print("=====test_fbank_default=====")
|
||||||
|
filename = cur_dir / "test_data/test.wav"
|
||||||
|
wave = read_wave(filename)
|
||||||
|
gt = read_ark_txt(cur_dir / "test_data/test.txt")
|
||||||
|
|
||||||
|
cpu_features = None
|
||||||
for device in get_devices():
|
for device in get_devices():
|
||||||
print("device", device)
|
print("device", device)
|
||||||
opts = kaldifeat.FbankOptions()
|
opts = kaldifeat.FbankOptions()
|
||||||
opts.device = device
|
opts.device = device
|
||||||
opts.frame_opts.dither = 0
|
opts.frame_opts.dither = 0
|
||||||
fbank = kaldifeat.Fbank(opts)
|
fbank = kaldifeat.Fbank(opts)
|
||||||
filename = cur_dir / "test_data/test.wav"
|
|
||||||
wave = read_wave(filename)
|
|
||||||
|
|
||||||
features = fbank(wave)
|
features = fbank(wave)
|
||||||
assert features.device.type == "cpu"
|
assert features.device.type == "cpu"
|
||||||
gt = read_ark_txt(cur_dir / "test_data/test.txt")
|
|
||||||
assert torch.allclose(features, gt, rtol=1e-1)
|
assert torch.allclose(features, gt, rtol=1e-1)
|
||||||
|
if cpu_features is None:
|
||||||
|
cpu_features = features
|
||||||
|
|
||||||
wave = wave.to(device)
|
features = fbank(wave.to(device))
|
||||||
features = fbank(wave)
|
|
||||||
assert features.device == device
|
assert features.device == device
|
||||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||||
|
|
||||||
|
# Now for online fbank
|
||||||
|
opts = kaldifeat.FbankOptions()
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.frame_opts.max_feature_vectors = 100
|
||||||
|
|
||||||
|
test_online_fbank(opts, wave, cpu_features)
|
||||||
|
|
||||||
|
|
||||||
def test_fbank_htk():
|
def test_fbank_htk():
|
||||||
print("=====test_fbank_htk=====")
|
print("=====test_fbank_htk=====")
|
||||||
|
filename = cur_dir / "test_data/test.wav"
|
||||||
|
wave = read_wave(filename)
|
||||||
|
gt = read_ark_txt(cur_dir / "test_data/test-htk.txt")
|
||||||
|
|
||||||
|
cpu_features = None
|
||||||
for device in get_devices():
|
for device in get_devices():
|
||||||
print("device", device)
|
print("device", device)
|
||||||
opts = kaldifeat.FbankOptions()
|
opts = kaldifeat.FbankOptions()
|
||||||
@ -46,22 +104,32 @@ def test_fbank_htk():
|
|||||||
opts.htk_compat = True
|
opts.htk_compat = True
|
||||||
|
|
||||||
fbank = kaldifeat.Fbank(opts)
|
fbank = kaldifeat.Fbank(opts)
|
||||||
filename = cur_dir / "test_data/test.wav"
|
|
||||||
wave = read_wave(filename)
|
|
||||||
|
|
||||||
features = fbank(wave)
|
features = fbank(wave)
|
||||||
assert features.device.type == "cpu"
|
assert features.device.type == "cpu"
|
||||||
gt = read_ark_txt(cur_dir / "test_data/test-htk.txt")
|
|
||||||
assert torch.allclose(features, gt, rtol=1e-1)
|
assert torch.allclose(features, gt, rtol=1e-1)
|
||||||
|
if cpu_features is None:
|
||||||
|
cpu_features = features
|
||||||
|
|
||||||
wave = wave.to(device)
|
features = fbank(wave.to(device))
|
||||||
features = fbank(wave)
|
|
||||||
assert features.device == device
|
assert features.device == device
|
||||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||||
|
|
||||||
|
opts = kaldifeat.FbankOptions()
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.use_energy = True
|
||||||
|
opts.htk_compat = True
|
||||||
|
|
||||||
|
test_online_fbank(opts, wave, cpu_features)
|
||||||
|
|
||||||
|
|
||||||
def test_fbank_with_energy():
|
def test_fbank_with_energy():
|
||||||
print("=====test_fbank_with_energy=====")
|
print("=====test_fbank_with_energy=====")
|
||||||
|
filename = cur_dir / "test_data/test.wav"
|
||||||
|
wave = read_wave(filename)
|
||||||
|
gt = read_ark_txt(cur_dir / "test_data/test-with-energy.txt")
|
||||||
|
|
||||||
|
cpu_features = None
|
||||||
for device in get_devices():
|
for device in get_devices():
|
||||||
print("device", device)
|
print("device", device)
|
||||||
opts = kaldifeat.FbankOptions()
|
opts = kaldifeat.FbankOptions()
|
||||||
@ -70,22 +138,31 @@ def test_fbank_with_energy():
|
|||||||
opts.use_energy = True
|
opts.use_energy = True
|
||||||
|
|
||||||
fbank = kaldifeat.Fbank(opts)
|
fbank = kaldifeat.Fbank(opts)
|
||||||
filename = cur_dir / "test_data/test.wav"
|
|
||||||
wave = read_wave(filename)
|
|
||||||
|
|
||||||
features = fbank(wave)
|
features = fbank(wave)
|
||||||
gt = read_ark_txt(cur_dir / "test_data/test-with-energy.txt")
|
|
||||||
assert torch.allclose(features, gt, rtol=1e-1)
|
assert torch.allclose(features, gt, rtol=1e-1)
|
||||||
assert features.device.type == "cpu"
|
assert features.device.type == "cpu"
|
||||||
|
if cpu_features is None:
|
||||||
|
cpu_features = features
|
||||||
|
|
||||||
wave = wave.to(device)
|
features = fbank(wave.to(device))
|
||||||
features = fbank(wave)
|
|
||||||
assert features.device == device
|
assert features.device == device
|
||||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||||
|
|
||||||
|
opts = kaldifeat.FbankOptions()
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.use_energy = True
|
||||||
|
|
||||||
|
test_online_fbank(opts, wave, cpu_features)
|
||||||
|
|
||||||
|
|
||||||
def test_fbank_40_bins():
|
def test_fbank_40_bins():
|
||||||
print("=====test_fbank_40_bins=====")
|
print("=====test_fbank_40_bins=====")
|
||||||
|
filename = cur_dir / "test_data/test.wav"
|
||||||
|
wave = read_wave(filename)
|
||||||
|
gt = read_ark_txt(cur_dir / "test_data/test-40.txt")
|
||||||
|
|
||||||
|
cpu_features = None
|
||||||
for device in get_devices():
|
for device in get_devices():
|
||||||
print("device", device)
|
print("device", device)
|
||||||
opts = kaldifeat.FbankOptions()
|
opts = kaldifeat.FbankOptions()
|
||||||
@ -94,22 +171,31 @@ def test_fbank_40_bins():
|
|||||||
opts.mel_opts.num_bins = 40
|
opts.mel_opts.num_bins = 40
|
||||||
|
|
||||||
fbank = kaldifeat.Fbank(opts)
|
fbank = kaldifeat.Fbank(opts)
|
||||||
filename = cur_dir / "test_data/test.wav"
|
|
||||||
wave = read_wave(filename)
|
|
||||||
|
|
||||||
features = fbank(wave)
|
features = fbank(wave)
|
||||||
assert features.device.type == "cpu"
|
assert features.device.type == "cpu"
|
||||||
gt = read_ark_txt(cur_dir / "test_data/test-40.txt")
|
|
||||||
assert torch.allclose(features, gt, rtol=1e-1)
|
assert torch.allclose(features, gt, rtol=1e-1)
|
||||||
|
if cpu_features is None:
|
||||||
|
cpu_features = features
|
||||||
|
|
||||||
wave = wave.to(device)
|
features = fbank(wave.to(device))
|
||||||
features = fbank(wave)
|
|
||||||
assert features.device == device
|
assert features.device == device
|
||||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||||
|
|
||||||
|
opts = kaldifeat.FbankOptions()
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.mel_opts.num_bins = 40
|
||||||
|
|
||||||
|
test_online_fbank(opts, wave, cpu_features)
|
||||||
|
|
||||||
|
|
||||||
def test_fbank_40_bins_no_snip_edges():
|
def test_fbank_40_bins_no_snip_edges():
|
||||||
print("=====test_fbank_40_bins_no_snip_edges=====")
|
print("=====test_fbank_40_bins_no_snip_edges=====")
|
||||||
|
filename = cur_dir / "test_data/test.wav"
|
||||||
|
wave = read_wave(filename)
|
||||||
|
gt = read_ark_txt(cur_dir / "test_data/test-40-no-snip-edges.txt")
|
||||||
|
|
||||||
|
cpu_features = None
|
||||||
for device in get_devices():
|
for device in get_devices():
|
||||||
print("device", device)
|
print("device", device)
|
||||||
opts = kaldifeat.FbankOptions()
|
opts = kaldifeat.FbankOptions()
|
||||||
@ -119,19 +205,24 @@ def test_fbank_40_bins_no_snip_edges():
|
|||||||
opts.frame_opts.snip_edges = False
|
opts.frame_opts.snip_edges = False
|
||||||
|
|
||||||
fbank = kaldifeat.Fbank(opts)
|
fbank = kaldifeat.Fbank(opts)
|
||||||
filename = cur_dir / "test_data/test.wav"
|
|
||||||
wave = read_wave(filename)
|
|
||||||
|
|
||||||
features = fbank(wave)
|
features = fbank(wave)
|
||||||
assert features.device.type == "cpu"
|
assert features.device.type == "cpu"
|
||||||
gt = read_ark_txt(cur_dir / "test_data/test-40-no-snip-edges.txt")
|
|
||||||
assert torch.allclose(features, gt, rtol=1e-1)
|
assert torch.allclose(features, gt, rtol=1e-1)
|
||||||
|
if cpu_features is None:
|
||||||
|
cpu_features = features
|
||||||
|
|
||||||
wave = wave.to(device)
|
features = fbank(wave.to(device))
|
||||||
features = fbank(wave)
|
|
||||||
assert features.device == device
|
assert features.device == device
|
||||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||||
|
|
||||||
|
opts = kaldifeat.FbankOptions()
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.mel_opts.num_bins = 40
|
||||||
|
opts.frame_opts.snip_edges = False
|
||||||
|
|
||||||
|
test_online_fbank(opts, wave, cpu_features)
|
||||||
|
|
||||||
|
|
||||||
def test_fbank_chunk():
|
def test_fbank_chunk():
|
||||||
print("=====test_fbank_chunk=====")
|
print("=====test_fbank_chunk=====")
|
||||||
@ -223,6 +314,16 @@ def test_pickle():
|
|||||||
|
|
||||||
assert str(fbank.opts) == str(fbank2.opts)
|
assert str(fbank.opts) == str(fbank2.opts)
|
||||||
|
|
||||||
|
opts = kaldifeat.FbankOptions()
|
||||||
|
opts.use_energy = True
|
||||||
|
opts.use_power = False
|
||||||
|
|
||||||
|
fbank = kaldifeat.OnlineFbank(opts)
|
||||||
|
data = pickle.dumps(fbank)
|
||||||
|
fbank2 = pickle.loads(data)
|
||||||
|
|
||||||
|
assert str(fbank.opts) == str(fbank2.opts)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_fbank_default()
|
test_fbank_default()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user