mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-09 10:02:20 +00:00
Support computing features for whisper (#82)
This commit is contained in:
parent
7912c2f442
commit
01aed93b1b
@ -7,7 +7,7 @@ project(kaldifeat)
|
||||
# remember to change the version in
|
||||
# scripts/conda/kaldifeat/meta.yaml
|
||||
# scripts/conda-cpu/kaldifeat/meta.yaml
|
||||
set(kaldifeat_VERSION "1.25.1")
|
||||
set(kaldifeat_VERSION "1.25.2")
|
||||
|
||||
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
||||
|
@ -10,6 +10,7 @@ set(kaldifeat_srcs
|
||||
matrix-functions.cc
|
||||
mel-computations.cc
|
||||
online-feature.cc
|
||||
whisper-fbank.cc
|
||||
)
|
||||
|
||||
add_library(kaldifeat_core ${kaldifeat_srcs})
|
||||
|
1
kaldifeat/csrc/CPPLINT.cfg
Normal file
1
kaldifeat/csrc/CPPLINT.cfg
Normal file
@ -0,0 +1 @@
|
||||
exclude_files=whisper-mel-bank.h
|
@ -65,7 +65,7 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy,
|
||||
// note spectrum is in magnitude, not power, because of `abs()`
|
||||
#if defined(KALDIFEAT_HAS_FFT_NAMESPACE)
|
||||
// signal_frame shape: [x, 512]
|
||||
// spectrum shape [x, 257
|
||||
// spectrum shape [x, 257]
|
||||
torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
|
||||
#else
|
||||
// signal_frame shape [x, 512]
|
||||
|
@ -29,6 +29,13 @@ FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts,
|
||||
float *window_data = window.data_ptr<float>();
|
||||
|
||||
double a = M_2PI / (frame_length - 1);
|
||||
|
||||
if (opts.window_type == "hann") {
|
||||
// see https://pytorch.org/docs/stable/generated/torch.hann_window.html
|
||||
// We assume periodic is true
|
||||
a = M_2PI / frame_length;
|
||||
}
|
||||
|
||||
for (int32_t i = 0; i < frame_length; i++) {
|
||||
double i_fl = static_cast<double>(i);
|
||||
if (opts.window_type == "hanning") {
|
||||
@ -39,6 +46,8 @@ FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts,
|
||||
window_data[i] = sin(0.5 * a * i_fl);
|
||||
} else if (opts.window_type == "hamming") {
|
||||
window_data[i] = 0.54 - 0.46 * cos(a * i_fl);
|
||||
} else if (opts.window_type == "hann") {
|
||||
window_data[i] = 0.50 - 0.50 * cos(a * i_fl);
|
||||
} else if (opts.window_type ==
|
||||
"povey") { // like hamming but goes to zero at edges.
|
||||
window_data[i] = pow(0.5 - 0.5 * cos(a * i_fl), 0.85);
|
||||
|
39
kaldifeat/csrc/generate-whisper-melbank.py
Executable file
39
kaldifeat/csrc/generate-whisper-melbank.py
Executable file
@ -0,0 +1,39 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
|
||||
|
||||
def main():
|
||||
m = librosa.filters.mel(sr=16000, n_fft=400, n_mels=80)
|
||||
assert m.shape == (80, 201)
|
||||
s = "// Auto-generated. Do NOT edit!\n\n"
|
||||
s += "// Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)\n\n"
|
||||
s += "\n"
|
||||
s += "#ifndef KALDIFEAT_CSRC_WHISPER_MEL_BANK_H_\n"
|
||||
s += "#define KALDIFEAT_CSRC_WHISPER_MEL_BANK_H_\n"
|
||||
s += "namespace kaldifeat {\n\n"
|
||||
s += f"constexpr int32_t kWhisperMelRows = {m.shape[0]};\n"
|
||||
s += f"constexpr int32_t kWhisperMelCols = {m.shape[1]};\n"
|
||||
s += "\n"
|
||||
s += "constexpr float kWhisperMelArray[] = {\n"
|
||||
sep = ""
|
||||
for i, f in enumerate(m.reshape(-1).tolist()):
|
||||
s += f"{sep}{f:.8f}"
|
||||
sep = ", "
|
||||
if i and i % 7 == 0:
|
||||
s += ",\n"
|
||||
sep = ""
|
||||
|
||||
s += "};\n\n"
|
||||
s += "} // namespace kaldifeat\n\n"
|
||||
s += "#endif // KALDIFEAT_CSRC_WHISPER_MEL_BANK_H_\n"
|
||||
|
||||
with open("whisper-mel-bank.h", "w") as f:
|
||||
f.write(s)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -196,6 +196,15 @@ MelBanks::MelBanks(const MelBanksOptions &opts,
|
||||
}
|
||||
}
|
||||
|
||||
MelBanks::MelBanks(const float *weights, int32_t num_rows, int32_t num_cols,
|
||||
torch::Device device)
|
||||
: debug_(false), htk_mode_(false) {
|
||||
bins_mat_ = torch::from_blob(const_cast<float *>(weights),
|
||||
{num_rows, num_cols}, torch::kFloat)
|
||||
.t()
|
||||
.to(device);
|
||||
}
|
||||
|
||||
torch::Tensor MelBanks::Compute(const torch::Tensor &spectrum) const {
|
||||
return torch::mm(spectrum, bins_mat_);
|
||||
}
|
||||
|
@ -76,6 +76,17 @@ class MelBanks {
|
||||
const FrameExtractionOptions &frame_opts, float vtln_warp_factor,
|
||||
torch::Device device);
|
||||
|
||||
// Initialize with a 2-d weights matrix
|
||||
//
|
||||
// Note: This constructor is for Whisper. It does not initialize
|
||||
// center_freqs_.
|
||||
//
|
||||
// @param weights Pointer to the start address of the matrix
|
||||
// @param num_rows It equals to number of mel bins
|
||||
// @param num_cols It equals to (number of fft bins)/2+1
|
||||
MelBanks(const float *weights, int32_t num_rows, int32_t num_cols,
|
||||
torch::Device device);
|
||||
|
||||
// CAUTION: we save a transposed version of bins_mat_, so return size(1) here
|
||||
int32_t NumBins() const { return static_cast<int32_t>(bins_mat_.size(1)); }
|
||||
|
||||
@ -89,7 +100,8 @@ class MelBanks {
|
||||
|
||||
private:
|
||||
// A 2-D matrix. Its shape is NOT [num_bins, num_fft_bins]
|
||||
// Its shape is [num_fft_bins, num_bins].
|
||||
// Its shape is [num_fft_bins, num_bins] for non-whisper.
|
||||
// For whisper, its shape is [num_fft_bins/2+1, num_bins]
|
||||
torch::Tensor bins_mat_;
|
||||
|
||||
// center frequencies of bins, numbered from 0 ... num_bins-1.
|
||||
|
78
kaldifeat/csrc/whisper-fbank.cc
Normal file
78
kaldifeat/csrc/whisper-fbank.cc
Normal file
@ -0,0 +1,78 @@
|
||||
/**
|
||||
* Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
*
|
||||
* See LICENSE for clarification regarding multiple authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "kaldifeat/csrc/whisper-fbank.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include "kaldifeat/csrc/mel-computations.h"
|
||||
#include "kaldifeat/csrc/whisper-mel-bank.h"
|
||||
|
||||
#ifndef M_2PI
|
||||
#define M_2PI 6.283185307179586476925286766559005
|
||||
#endif
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
WhisperFbankComputer::WhisperFbankComputer(const WhisperFbankOptions &opts)
|
||||
: opts_(opts),
|
||||
mel_banks_(kWhisperMelArray, kWhisperMelRows, kWhisperMelCols,
|
||||
opts.device) {
|
||||
opts_.frame_opts.samp_freq = 16000;
|
||||
opts_.frame_opts.frame_shift_ms = 10;
|
||||
opts_.frame_opts.frame_length_ms = 25;
|
||||
opts_.frame_opts.dither = 0;
|
||||
opts_.frame_opts.preemph_coeff = 0;
|
||||
opts_.frame_opts.remove_dc_offset = false;
|
||||
opts_.frame_opts.window_type = "hann";
|
||||
opts_.frame_opts.round_to_power_of_two = false;
|
||||
opts_.frame_opts.snip_edges = false;
|
||||
}
|
||||
|
||||
torch::Tensor WhisperFbankComputer::Compute(
|
||||
torch::Tensor /*signal_raw_log_energy*/, float /*vtln_warp*/,
|
||||
const torch::Tensor &signal_frame) {
|
||||
KALDIFEAT_ASSERT(signal_frame.dim() == 2);
|
||||
KALDIFEAT_ASSERT(signal_frame.size(1) == opts_.frame_opts.PaddedWindowSize());
|
||||
|
||||
// note spectrum is in magnitude, not power, because of `abs()`
|
||||
#if defined(KALDIFEAT_HAS_FFT_NAMESPACE)
|
||||
// signal_frame shape: [x, 512]
|
||||
// power shape [x, 257]
|
||||
torch::Tensor power = torch::fft::rfft(signal_frame).abs().pow(2);
|
||||
#else
|
||||
// signal_frame shape [x, 512]
|
||||
// real_imag shape [x, 257, 2],
|
||||
// where [..., 0] is the real part
|
||||
// [..., 1] is the imaginary part
|
||||
torch::Tensor real_imag = torch::rfft(signal_frame, 1);
|
||||
torch::Tensor real = real_imag.index({"...", 0});
|
||||
torch::Tensor imag = real_imag.index({"...", 1});
|
||||
torch::Tensor power = (real.square() + imag.square());
|
||||
#endif
|
||||
|
||||
torch::Tensor mel_energies = mel_banks_.Compute(power);
|
||||
torch::Tensor log_spec = torch::clamp_min(mel_energies, 1e-10).log10();
|
||||
log_spec = torch::maximum(log_spec, log_spec.max() - 8.0);
|
||||
torch::Tensor mel = (log_spec + 4.0) / 4.0;
|
||||
|
||||
return mel;
|
||||
}
|
||||
|
||||
} // namespace kaldifeat
|
74
kaldifeat/csrc/whisper-fbank.h
Normal file
74
kaldifeat/csrc/whisper-fbank.h
Normal file
@ -0,0 +1,74 @@
|
||||
/**
|
||||
* Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
*
|
||||
* See LICENSE for clarification regarding multiple authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef KALDIFEAT_CSRC_WHISPER_FBANK_H_
|
||||
#define KALDIFEAT_CSRC_WHISPER_FBANK_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "kaldifeat/csrc/feature-common.h"
|
||||
#include "kaldifeat/csrc/feature-window.h"
|
||||
#include "kaldifeat/csrc/mel-computations.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
struct WhisperFbankOptions {
|
||||
FrameExtractionOptions frame_opts;
|
||||
|
||||
torch::Device device{"cpu"};
|
||||
std::string ToString() const {
|
||||
std::ostringstream os;
|
||||
os << "WhisperFbankOptions(";
|
||||
os << "frame_opts=" << frame_opts.ToString() << ", ";
|
||||
os << "device=\"" << device << "\")";
|
||||
return os.str();
|
||||
}
|
||||
};
|
||||
|
||||
class WhisperFbankComputer {
|
||||
public:
|
||||
// note: Only frame_opts.device is used. All other fields from frame_opts
|
||||
// are ignored
|
||||
explicit WhisperFbankComputer(const WhisperFbankOptions &opts = {});
|
||||
|
||||
int32_t Dim() const { return 80; }
|
||||
|
||||
const FrameExtractionOptions &GetFrameOptions() const {
|
||||
return opts_.frame_opts;
|
||||
}
|
||||
|
||||
const WhisperFbankOptions &GetOptions() const { return opts_; }
|
||||
|
||||
torch::Tensor Compute(torch::Tensor /*signal_raw_log_energy*/,
|
||||
float /*vtln_warp*/, const torch::Tensor &signal_frame);
|
||||
|
||||
// if true, compute log_energy_pre_window but after dithering and dc removal
|
||||
bool NeedRawLogEnergy() const { return false; }
|
||||
using Options = WhisperFbankOptions;
|
||||
|
||||
private:
|
||||
WhisperFbankOptions opts_;
|
||||
MelBanks mel_banks_;
|
||||
};
|
||||
|
||||
using WhisperFbank = OfflineFeatureTpl<WhisperFbankComputer>;
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
||||
#endif // KALDIFEAT_CSRC_WHISPER_FBANK_H_
|
2315
kaldifeat/csrc/whisper-mel-bank.h
Normal file
2315
kaldifeat/csrc/whisper-mel-bank.h
Normal file
File diff suppressed because it is too large
Load Diff
@ -9,6 +9,7 @@ pybind11_add_module(_kaldifeat
|
||||
mel-computations.cc
|
||||
online-feature.cc
|
||||
utils.cc
|
||||
whisper-fbank.cc
|
||||
)
|
||||
|
||||
if(APPLE)
|
||||
|
@ -12,6 +12,7 @@
|
||||
#include "kaldifeat/python/csrc/feature-window.h"
|
||||
#include "kaldifeat/python/csrc/mel-computations.h"
|
||||
#include "kaldifeat/python/csrc/online-feature.h"
|
||||
#include "kaldifeat/python/csrc/whisper-fbank.h"
|
||||
#include "torch/torch.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
@ -22,6 +23,7 @@ PYBIND11_MODULE(_kaldifeat, m) {
|
||||
PybindFeatureWindow(m);
|
||||
PybindMelComputations(m);
|
||||
PybindFeatureFbank(m);
|
||||
PybindWhisperFbank(m);
|
||||
PybindFeatureMfcc(m);
|
||||
PybindFeaturePlp(m);
|
||||
PybindFeatureSpectrogram(m);
|
||||
|
@ -123,6 +123,31 @@ py::dict AsDict(const FbankOptions &opts) {
|
||||
return dict;
|
||||
}
|
||||
|
||||
WhisperFbankOptions WhisperFbankOptionsFromDict(py::dict dict) {
|
||||
WhisperFbankOptions opts;
|
||||
|
||||
if (dict.contains("frame_opts")) {
|
||||
opts.frame_opts = FrameExtractionOptionsFromDict(dict["frame_opts"]);
|
||||
}
|
||||
|
||||
if (dict.contains("device")) {
|
||||
opts.device = torch::Device(std::string(py::str(dict["device"])));
|
||||
}
|
||||
|
||||
return opts;
|
||||
}
|
||||
|
||||
py::dict AsDict(const WhisperFbankOptions &opts) {
|
||||
py::dict dict;
|
||||
|
||||
dict["frame_opts"] = AsDict(opts.frame_opts);
|
||||
|
||||
auto torch_device = py::module_::import("torch").attr("device");
|
||||
dict["device"] = torch_device(opts.device.str());
|
||||
|
||||
return dict;
|
||||
}
|
||||
|
||||
MfccOptions MfccOptionsFromDict(py::dict dict) {
|
||||
MfccOptions opts;
|
||||
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include "kaldifeat/csrc/feature-spectrogram.h"
|
||||
#include "kaldifeat/csrc/feature-window.h"
|
||||
#include "kaldifeat/csrc/mel-computations.h"
|
||||
#include "kaldifeat/csrc/whisper-fbank.h"
|
||||
#include "kaldifeat/python/csrc/kaldifeat.h"
|
||||
|
||||
/*
|
||||
@ -36,6 +37,9 @@ py::dict AsDict(const MelBanksOptions &opts);
|
||||
FbankOptions FbankOptionsFromDict(py::dict dict);
|
||||
py::dict AsDict(const FbankOptions &opts);
|
||||
|
||||
WhisperFbankOptions WhisperFbankOptionsFromDict(py::dict dict);
|
||||
py::dict AsDict(const WhisperFbankOptions &opts);
|
||||
|
||||
MfccOptions MfccOptionsFromDict(py::dict dict);
|
||||
py::dict AsDict(const MfccOptions &opts);
|
||||
|
||||
|
81
kaldifeat/python/csrc/whisper-fbank.cc
Normal file
81
kaldifeat/python/csrc/whisper-fbank.cc
Normal file
@ -0,0 +1,81 @@
|
||||
// kaldifeat/python/csrc/whisper-fbank.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
#include "kaldifeat/python/csrc/whisper-fbank.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "kaldifeat/csrc/whisper-fbank.h"
|
||||
#include "kaldifeat/python/csrc/utils.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
static void PybindWhisperFbankOptions(py::module &m) {
|
||||
using PyClass = WhisperFbankOptions;
|
||||
py::class_<PyClass>(m, "WhisperFbankOptions")
|
||||
.def(py::init<>())
|
||||
.def(py::init([](const FrameExtractionOptions &frame_opts =
|
||||
FrameExtractionOptions(),
|
||||
py::object device = py::str(
|
||||
"cpu")) -> std::unique_ptr<WhisperFbankOptions> {
|
||||
auto opts = std::make_unique<WhisperFbankOptions>();
|
||||
opts->frame_opts = frame_opts;
|
||||
|
||||
std::string s = static_cast<py::str>(device);
|
||||
opts->device = torch::Device(s);
|
||||
|
||||
return opts;
|
||||
}),
|
||||
py::arg("frame_opts") = FrameExtractionOptions(),
|
||||
py::arg("device") = py::str("cpu"))
|
||||
.def_readwrite("frame_opts", &PyClass::frame_opts)
|
||||
.def_property(
|
||||
"device",
|
||||
[](const PyClass &self) -> py::object {
|
||||
py::object ans = py::module_::import("torch").attr("device");
|
||||
return ans(self.device.str());
|
||||
},
|
||||
[](PyClass &self, py::object obj) -> void {
|
||||
std::string s = static_cast<py::str>(obj);
|
||||
self.device = torch::Device(s);
|
||||
})
|
||||
.def("__str__",
|
||||
[](const PyClass &self) -> std::string { return self.ToString(); })
|
||||
.def("as_dict",
|
||||
[](const PyClass &self) -> py::dict { return AsDict(self); })
|
||||
.def_static("from_dict",
|
||||
[](py::dict dict) -> PyClass {
|
||||
return WhisperFbankOptionsFromDict(dict);
|
||||
})
|
||||
.def(py::pickle(
|
||||
[](const PyClass &self) -> py::dict { return AsDict(self); },
|
||||
[](py::dict dict) -> PyClass {
|
||||
return WhisperFbankOptionsFromDict(dict);
|
||||
}));
|
||||
}
|
||||
|
||||
static void PybindWhisperFbankImpl(py::module &m) {
|
||||
using PyClass = WhisperFbank;
|
||||
py::class_<PyClass>(m, "WhisperFbank")
|
||||
.def(py::init<const WhisperFbankOptions &>(), py::arg("opts"))
|
||||
.def("dim", &PyClass::Dim)
|
||||
.def_property_readonly("options", &PyClass::GetOptions)
|
||||
.def("compute_features", &PyClass::ComputeFeatures, py::arg("wave"),
|
||||
py::arg("vtln_warp"), py::call_guard<py::gil_scoped_release>())
|
||||
.def(py::pickle(
|
||||
[](const PyClass &self) -> py::dict {
|
||||
return AsDict(self.GetOptions());
|
||||
},
|
||||
[](py::dict dict) -> std::unique_ptr<PyClass> {
|
||||
return std::make_unique<PyClass>(WhisperFbankOptionsFromDict(dict));
|
||||
}));
|
||||
}
|
||||
|
||||
void PybindWhisperFbank(py::module &m) {
|
||||
PybindWhisperFbankOptions(m);
|
||||
PybindWhisperFbankImpl(m);
|
||||
}
|
||||
|
||||
} // namespace kaldifeat
|
16
kaldifeat/python/csrc/whisper-fbank.h
Normal file
16
kaldifeat/python/csrc/whisper-fbank.h
Normal file
@ -0,0 +1,16 @@
|
||||
// kaldifeat/python/csrc/whisper-fbank.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
#ifndef KALDIFEAT_PYTHON_CSRC_WHISPER_FBANK_H_
|
||||
#define KALDIFEAT_PYTHON_CSRC_WHISPER_FBANK_H_
|
||||
|
||||
#include "kaldifeat/python/csrc/kaldifeat.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
void PybindWhisperFbank(py::module &m);
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
||||
#endif // KALDIFEAT_PYTHON_CSRC_WHISPER_FBANK_H_
|
@ -17,6 +17,7 @@ from _kaldifeat import (
|
||||
MfccOptions,
|
||||
PlpOptions,
|
||||
SpectrogramOptions,
|
||||
WhisperFbankOptions,
|
||||
num_frames,
|
||||
)
|
||||
|
||||
@ -26,6 +27,7 @@ from .offline_feature import OfflineFeature
|
||||
from .online_feature import OnlineFeature
|
||||
from .plp import OnlinePlp, Plp
|
||||
from .spectrogram import Spectrogram
|
||||
from .whisper_fbank import WhisperFbank
|
||||
|
||||
cmake_prefix_path = _Path(__file__).parent / "share" / "cmake"
|
||||
del _Path
|
||||
|
12
kaldifeat/python/kaldifeat/whisper_fbank.py
Normal file
12
kaldifeat/python/kaldifeat/whisper_fbank.py
Normal file
@ -0,0 +1,12 @@
|
||||
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
|
||||
import _kaldifeat
|
||||
|
||||
from .offline_feature import OfflineFeature
|
||||
|
||||
|
||||
class WhisperFbank(OfflineFeature):
|
||||
def __init__(self, opts: _kaldifeat.WhisperFbankOptions):
|
||||
super().__init__(opts)
|
||||
self.computer = _kaldifeat.WhisperFbank(opts)
|
48
kaldifeat/python/tests/test_whisper_fbank.py
Normal file
48
kaldifeat/python/tests/test_whisper_fbank.py
Normal file
@ -0,0 +1,48 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
import librosa
|
||||
import torch
|
||||
import kaldifeat
|
||||
|
||||
|
||||
def get_ground_truth(x):
|
||||
N_FFT = 400
|
||||
HOP_LENGTH = 160
|
||||
|
||||
m = librosa.filters.mel(sr=16000, n_fft=400, n_mels=80)
|
||||
m = torch.from_numpy(m)
|
||||
# print(m.shape) # [80, 201]
|
||||
window = torch.hann_window(N_FFT)
|
||||
stft = torch.stft(x, N_FFT, HOP_LENGTH, window=window, return_complex=True)
|
||||
# print(stft.shape) # [201, 301]
|
||||
magnitudes = stft[..., :-1].abs() ** 2
|
||||
# print(magnitudes.shape) # [201, 300]
|
||||
|
||||
mel_spec = m @ magnitudes
|
||||
# print(mel_spec.shape) # [80, 300]
|
||||
|
||||
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
||||
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
||||
log_spec = (log_spec + 4.0) / 4.0
|
||||
return log_spec.t()
|
||||
|
||||
|
||||
def test_whisper_fbank():
|
||||
x = torch.rand(16000 * 3)
|
||||
|
||||
gt = get_ground_truth(x)
|
||||
print(gt.shape) # [300, 80]
|
||||
|
||||
opts = kaldifeat.WhisperFbankOptions(device="cpu")
|
||||
whisper_fbank = kaldifeat.WhisperFbank(opts)
|
||||
y = whisper_fbank(x) # [298, 80]
|
||||
print(y.shape) # [298, 80]
|
||||
|
||||
# print(gt[:5, :5])
|
||||
# print(y[:5, :5])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(20231108)
|
||||
test_whisper_fbank()
|
@ -1,6 +1,6 @@
|
||||
package:
|
||||
name: kaldifeat
|
||||
version: "1.25.1"
|
||||
version: "1.25.2"
|
||||
|
||||
source:
|
||||
path: "{{ environ.get('KALDIFEAT_ROOT_DIR') }}"
|
||||
|
@ -1,6 +1,6 @@
|
||||
package:
|
||||
name: kaldifeat
|
||||
version: "1.25.1"
|
||||
version: "1.25.2"
|
||||
|
||||
source:
|
||||
path: "{{ environ.get('KALDIFEAT_ROOT_DIR') }}"
|
||||
|
Loading…
x
Reference in New Issue
Block a user