support whisper v3 (#84)

This commit is contained in:
Fangjun Kuang 2023-11-09 12:45:56 +08:00 committed by GitHub
parent 20379449fc
commit 2624da8275
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 3834 additions and 17 deletions

View File

@ -7,7 +7,7 @@ project(kaldifeat)
# remember to change the version in # remember to change the version in
# scripts/conda/kaldifeat/meta.yaml # scripts/conda/kaldifeat/meta.yaml
# scripts/conda-cpu/kaldifeat/meta.yaml # scripts/conda-cpu/kaldifeat/meta.yaml
set(kaldifeat_VERSION "1.25.2") set(kaldifeat_VERSION "1.25.3")
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")

View File

@ -38,6 +38,21 @@ See <a href="https://github.com/csukuangfj/kaldifeat/pull/82">#82</a>
</td> </td>
</tr> </tr>
<tr>
<td>Fbank for <a href="https://github.com/openai/whisper">Whisper v3</a></td>
<td><code>kaldifeat.WhisperFbankOptions</code></td>
<td><code>kaldifeat.WhisperFbank</code></td>
<td>
<pre lang="python">
opts = kaldifeat.WhisperFbankOptions()
opts.num_mels = 128
opts.device = torch.device('cuda', 0)
fbank = kaldifeat.WhisperFbank(opts)
features = fbank(wave)
</pre>
</td>
</tr>
<tr> <tr>
<td>FBANK</td> <td>FBANK</td>
<td><code>kaldifeat.FbankOptions</code></td> <td><code>kaldifeat.FbankOptions</code></td>

View File

@ -1 +1 @@
exclude_files=whisper-mel-bank.h exclude_files=whisper-mel-bank.h,whisper-v3-mel-bank.h

View File

@ -0,0 +1,39 @@
#!/usr/bin/env python3
# Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
import librosa
import numpy as np
def main():
    """Generate ``whisper-v3-mel-bank.h`` from librosa's 128-bin mel filterbank.

    The filterbank is built for ``sr=16000`` and ``n_fft=400``. Its row count,
    column count and flattened (row-major) coefficients are emitted as
    ``constexpr`` definitions inside ``namespace kaldifeat``. The header is
    written to the current working directory.
    """
    mel = librosa.filters.mel(sr=16000, n_fft=400, n_mels=128)
    assert mel.shape == (128, 201)
    num_rows, num_cols = mel.shape

    # Collect the header text piece by piece and join once at the end.
    pieces = [
        "// Auto-generated. Do NOT edit!\n\n",
        "// Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)\n\n",
        "\n",
        "#ifndef KALDIFEAT_CSRC_WHISPER_V3_MEL_BANK_H_\n",
        "#define KALDIFEAT_CSRC_WHISPER_V3_MEL_BANK_H_\n",
        "namespace kaldifeat {\n\n",
        f"constexpr int32_t kWhisperV3MelRows = {num_rows};\n",
        f"constexpr int32_t kWhisperV3MelCols = {num_cols};\n",
        "\n",
        "constexpr float kWhisperV3MelArray[] = {\n",
    ]

    prefix = ""
    for idx, value in enumerate(mel.reshape(-1).tolist()):
        pieces.append(f"{prefix}{value:.8f}")
        if idx > 0 and idx % 7 == 0:
            # Close the current line of the initializer; the next value
            # starts a fresh line and therefore needs no leading separator.
            pieces.append(",\n")
            prefix = ""
        else:
            prefix = ", "

    pieces.append("};\n\n")
    pieces.append("} // namespace kaldifeat\n\n")
    pieces.append("#endif // KALDIFEAT_CSRC_WHISPER_V3_MEL_BANK_H_\n")

    with open("whisper-v3-mel-bank.h", "w") as out:
        out.write("".join(pieces))
# Regenerate the header only when this file is executed as a script.
if __name__ == "__main__":
    main()

View File

@ -23,6 +23,7 @@
#include "kaldifeat/csrc/mel-computations.h" #include "kaldifeat/csrc/mel-computations.h"
#include "kaldifeat/csrc/whisper-mel-bank.h" #include "kaldifeat/csrc/whisper-mel-bank.h"
#include "kaldifeat/csrc/whisper-v3-mel-bank.h"
#ifndef M_2PI #ifndef M_2PI
#define M_2PI 6.283185307179586476925286766559005 #define M_2PI 6.283185307179586476925286766559005
@ -31,9 +32,18 @@
namespace kaldifeat { namespace kaldifeat {
WhisperFbankComputer::WhisperFbankComputer(const WhisperFbankOptions &opts) WhisperFbankComputer::WhisperFbankComputer(const WhisperFbankOptions &opts)
: opts_(opts), : opts_(opts) {
mel_banks_(kWhisperMelArray, kWhisperMelRows, kWhisperMelCols, if (opts.num_mels == 80) {
opts.device) { mel_banks_ = std::make_unique<MelBanks>(kWhisperMelArray, kWhisperMelRows,
kWhisperMelCols, opts.device);
} else if (opts.num_mels == 128) {
mel_banks_ = std::make_unique<MelBanks>(
kWhisperV3MelArray, kWhisperV3MelRows, kWhisperV3MelCols, opts.device);
} else {
KALDIFEAT_ERR << "Unsupported num_mels: " << opts.num_mels
<< ". Support only 80 and 128";
}
opts_.frame_opts.samp_freq = 16000; opts_.frame_opts.samp_freq = 16000;
opts_.frame_opts.frame_shift_ms = 10; opts_.frame_opts.frame_shift_ms = 10;
opts_.frame_opts.frame_length_ms = 25; opts_.frame_opts.frame_length_ms = 25;
@ -67,7 +77,7 @@ torch::Tensor WhisperFbankComputer::Compute(
torch::Tensor power = (real.square() + imag.square()); torch::Tensor power = (real.square() + imag.square());
#endif #endif
torch::Tensor mel_energies = mel_banks_.Compute(power); torch::Tensor mel_energies = mel_banks_->Compute(power);
torch::Tensor log_spec = torch::clamp_min(mel_energies, 1e-10).log10(); torch::Tensor log_spec = torch::clamp_min(mel_energies, 1e-10).log10();
log_spec = torch::maximum(log_spec, log_spec.max() - 8.0); log_spec = torch::maximum(log_spec, log_spec.max() - 8.0);
torch::Tensor mel = (log_spec + 4.0) / 4.0; torch::Tensor mel = (log_spec + 4.0) / 4.0;

View File

@ -19,6 +19,7 @@
#ifndef KALDIFEAT_CSRC_WHISPER_FBANK_H_ #ifndef KALDIFEAT_CSRC_WHISPER_FBANK_H_
#define KALDIFEAT_CSRC_WHISPER_FBANK_H_ #define KALDIFEAT_CSRC_WHISPER_FBANK_H_
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
@ -30,12 +31,15 @@ namespace kaldifeat {
struct WhisperFbankOptions { struct WhisperFbankOptions {
FrameExtractionOptions frame_opts; FrameExtractionOptions frame_opts;
// for large v3, please use 128
int32_t num_mels = 80;
torch::Device device{"cpu"}; torch::Device device{"cpu"};
std::string ToString() const { std::string ToString() const {
std::ostringstream os; std::ostringstream os;
os << "WhisperFbankOptions("; os << "WhisperFbankOptions(";
os << "frame_opts=" << frame_opts.ToString() << ", "; os << "frame_opts=" << frame_opts.ToString() << ", ";
os << "num_mels=" << num_mels << ", ";
os << "device=\"" << device << "\")"; os << "device=\"" << device << "\")";
return os.str(); return os.str();
} }
@ -64,7 +68,7 @@ class WhisperFbankComputer {
private: private:
WhisperFbankOptions opts_; WhisperFbankOptions opts_;
MelBanks mel_banks_; std::unique_ptr<MelBanks> mel_banks_;
}; };
using WhisperFbank = OfflineFeatureTpl<WhisperFbankComputer>; using WhisperFbank = OfflineFeatureTpl<WhisperFbankComputer>;

File diff suppressed because it is too large Load Diff

View File

@ -23,7 +23,7 @@ PYBIND11_MODULE(_kaldifeat, m) {
PybindFeatureWindow(m); PybindFeatureWindow(m);
PybindMelComputations(m); PybindMelComputations(m);
PybindFeatureFbank(m); PybindFeatureFbank(m);
PybindWhisperFbank(m); PybindWhisperFbank(&m);
PybindFeatureMfcc(m); PybindFeatureMfcc(m);
PybindFeaturePlp(m); PybindFeaturePlp(m);
PybindFeatureSpectrogram(m); PybindFeatureSpectrogram(m);

View File

@ -130,6 +130,8 @@ WhisperFbankOptions WhisperFbankOptionsFromDict(py::dict dict) {
opts.frame_opts = FrameExtractionOptionsFromDict(dict["frame_opts"]); opts.frame_opts = FrameExtractionOptionsFromDict(dict["frame_opts"]);
} }
FROM_DICT(int_, num_mels);
if (dict.contains("device")) { if (dict.contains("device")) {
opts.device = torch::Device(std::string(py::str(dict["device"]))); opts.device = torch::Device(std::string(py::str(dict["device"])));
} }
@ -142,6 +144,8 @@ py::dict AsDict(const WhisperFbankOptions &opts) {
dict["frame_opts"] = AsDict(opts.frame_opts); dict["frame_opts"] = AsDict(opts.frame_opts);
AS_DICT(num_mels);
auto torch_device = py::module_::import("torch").attr("device"); auto torch_device = py::module_::import("torch").attr("device");
dict["device"] = torch_device(opts.device.str()); dict["device"] = torch_device(opts.device.str());

View File

@ -12,16 +12,18 @@
namespace kaldifeat { namespace kaldifeat {
static void PybindWhisperFbankOptions(py::module &m) { static void PybindWhisperFbankOptions(py::module *m) {
using PyClass = WhisperFbankOptions; using PyClass = WhisperFbankOptions;
py::class_<PyClass>(m, "WhisperFbankOptions") py::class_<PyClass>(*m, "WhisperFbankOptions")
.def(py::init<>()) .def(py::init<>())
.def(py::init([](const FrameExtractionOptions &frame_opts = .def(py::init([](const FrameExtractionOptions &frame_opts =
FrameExtractionOptions(), FrameExtractionOptions(),
int32_t num_mels = 80,
py::object device = py::str( py::object device = py::str(
"cpu")) -> std::unique_ptr<WhisperFbankOptions> { "cpu")) -> std::unique_ptr<WhisperFbankOptions> {
auto opts = std::make_unique<WhisperFbankOptions>(); auto opts = std::make_unique<WhisperFbankOptions>();
opts->frame_opts = frame_opts; opts->frame_opts = frame_opts;
opts->num_mels = num_mels;
std::string s = static_cast<py::str>(device); std::string s = static_cast<py::str>(device);
opts->device = torch::Device(s); opts->device = torch::Device(s);
@ -29,8 +31,9 @@ static void PybindWhisperFbankOptions(py::module &m) {
return opts; return opts;
}), }),
py::arg("frame_opts") = FrameExtractionOptions(), py::arg("frame_opts") = FrameExtractionOptions(),
py::arg("device") = py::str("cpu")) py::arg("num_mels") = 80, py::arg("device") = py::str("cpu"))
.def_readwrite("frame_opts", &PyClass::frame_opts) .def_readwrite("frame_opts", &PyClass::frame_opts)
.def_readwrite("num_mels", &PyClass::num_mels)
.def_property( .def_property(
"device", "device",
[](const PyClass &self) -> py::object { [](const PyClass &self) -> py::object {
@ -56,9 +59,9 @@ static void PybindWhisperFbankOptions(py::module &m) {
})); }));
} }
static void PybindWhisperFbankImpl(py::module &m) { static void PybindWhisperFbankImpl(py::module *m) {
using PyClass = WhisperFbank; using PyClass = WhisperFbank;
py::class_<PyClass>(m, "WhisperFbank") py::class_<PyClass>(*m, "WhisperFbank")
.def(py::init<const WhisperFbankOptions &>(), py::arg("opts")) .def(py::init<const WhisperFbankOptions &>(), py::arg("opts"))
.def("dim", &PyClass::Dim) .def("dim", &PyClass::Dim)
.def_property_readonly("options", &PyClass::GetOptions) .def_property_readonly("options", &PyClass::GetOptions)
@ -73,7 +76,7 @@ static void PybindWhisperFbankImpl(py::module &m) {
})); }));
} }
void PybindWhisperFbank(py::module &m) { void PybindWhisperFbank(py::module *m) {
PybindWhisperFbankOptions(m); PybindWhisperFbankOptions(m);
PybindWhisperFbankImpl(m); PybindWhisperFbankImpl(m);
} }

View File

@ -9,7 +9,7 @@
namespace kaldifeat { namespace kaldifeat {
void PybindWhisperFbank(py::module &m); void PybindWhisperFbank(py::module *m);
} // namespace kaldifeat } // namespace kaldifeat

View File

@ -0,0 +1,49 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corporation (authors: Fangjun Kuang)
import librosa
import torch
import kaldifeat
def get_ground_truth(x):
    """Compute reference 128-bin log-mel features with librosa and torch.

    Args:
      x: Waveform tensor handed to ``torch.stft``; the test in this file
        passes 3 seconds of samples at 16 kHz.

    Returns:
      The normalized log-mel spectrogram, transposed so that frames run along
      the first axis and mel bins along the second ([300, 128] for the
      3-second test input).
    """
    n_fft = 400
    hop_length = 160

    # Mel filterbank, shape [128, 201].
    mel_filters = torch.from_numpy(
        librosa.filters.mel(sr=16000, n_fft=400, n_mels=128)
    )

    spectrum = torch.stft(
        x,
        n_fft,
        hop_length,
        window=torch.hann_window(n_fft),
        return_complex=True,
    )
    # spectrum is [201, 301] for the test input; the last frame is dropped
    # before taking the squared magnitude, giving [201, 300].
    power = spectrum[..., :-1].abs() ** 2

    # [128, 201] @ [201, 300] -> [128, 300]
    mel_energies = mel_filters @ power

    log_mel = torch.clamp(mel_energies, min=1e-10).log10()
    # Limit the dynamic range to 8 (in log10 units) below the global maximum.
    log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
    normalized = (log_mel + 4.0) / 4.0
    return normalized.t()
def test_whisper_v3_fbank():
    """Run the 128-bin ``kaldifeat.WhisperFbank`` next to the torch reference.

    Three seconds of random audio are fed to both implementations. Shapes and
    the leading 5x5 corner of each feature matrix are printed for visual
    comparison. The two outputs have different frame counts for this input
    (300 vs. 298), so only the mel dimension is asserted.
    """
    x = torch.rand(16000 * 3)

    gt = get_ground_truth(x)
    print(gt.shape)  # [300, 128]

    opts = kaldifeat.WhisperFbankOptions(num_mels=128, device="cpu")
    print(opts)

    whisper_fbank = kaldifeat.WhisperFbank(opts)
    y = whisper_fbank(x)  # [298, 128]
    print(y.shape)  # [298, 128]

    print(gt[:5, :5])
    print(y[:5, :5])

    # Without these checks the test could never fail: both feature matrices
    # must carry exactly one column per requested mel bin.
    assert gt.shape[1] == opts.num_mels, gt.shape
    assert y.shape[1] == opts.num_mels, y.shape
# When run as a script, seed torch's RNG first so the random waveform created
# by torch.rand() inside the test is the same on every run.
if __name__ == "__main__":
    torch.manual_seed(20231109)
    test_whisper_v3_fbank()

View File

@ -1,6 +1,6 @@
package: package:
name: kaldifeat name: kaldifeat
version: "1.25.2" version: "1.25.3"
source: source:
path: "{{ environ.get('KALDIFEAT_ROOT_DIR') }}" path: "{{ environ.get('KALDIFEAT_ROOT_DIR') }}"

View File

@ -1,6 +1,6 @@
package: package:
name: kaldifeat name: kaldifeat
version: "1.25.2" version: "1.25.3"
source: source:
path: "{{ environ.get('KALDIFEAT_ROOT_DIR') }}" path: "{{ environ.get('KALDIFEAT_ROOT_DIR') }}"