mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-09 01:52:39 +00:00
support whisper v3 (#84)
This commit is contained in:
parent
20379449fc
commit
2624da8275
@ -7,7 +7,7 @@ project(kaldifeat)
|
||||
# remember to change the version in
|
||||
# scripts/conda/kaldifeat/meta.yaml
|
||||
# scripts/conda-cpu/kaldifeat/meta.yaml
|
||||
set(kaldifeat_VERSION "1.25.2")
|
||||
set(kaldifeat_VERSION "1.25.3")
|
||||
|
||||
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
||||
|
15
README.md
15
README.md
@ -38,6 +38,21 @@ See <a href="https://github.com/csukuangfj/kaldifeat/pull/82">#82</a>
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>Fbank for <a href="https://github.com/openai/whisper">Whisper v3</a></td>
|
||||
<td><code>kaldifeat.WhisperFbankOptions</code></td>
|
||||
<td><code>kaldifeat.WhisperFbank</code></td>
|
||||
<td>
|
||||
<pre lang="python">
|
||||
opts = kaldifeat.WhisperFbankOptions()
|
||||
opts.num_mels = 128
|
||||
opts.device = torch.device('cuda', 0)
|
||||
fbank = kaldifeat.WhisperFbank(opts)
|
||||
features = fbank(wave)
|
||||
</pre>
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>FBANK</td>
|
||||
<td><code>kaldifeat.FbankOptions</code></td>
|
||||
|
@ -1 +1 @@
|
||||
exclude_files=whisper-mel-bank.h
|
||||
exclude_files=whisper-mel-bank.h,whisper-v3-mel-bank.h
|
||||
|
39
kaldifeat/csrc/generate-whisper-melbank-v3.py
Executable file
39
kaldifeat/csrc/generate-whisper-melbank-v3.py
Executable file
@ -0,0 +1,39 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
|
||||
|
||||
def main():
    """Generate ``whisper-v3-mel-bank.h`` in the current working directory.

    The header holds the 128-bin mel filterbank returned by
    ``librosa.filters.mel(sr=16000, n_fft=400, n_mels=128)`` as a flat,
    row-major ``constexpr float`` array, together with its row/column counts.
    """
    m = librosa.filters.mel(sr=16000, n_fft=400, n_mels=128)
    assert m.shape == (128, 201)

    guard = "KALDIFEAT_CSRC_WHISPER_V3_MEL_BANK_H_"

    # Collect the pieces in a list and join once at the end instead of
    # repeatedly growing a string for ~25k values.
    parts = [
        "// Auto-generated. Do NOT edit!\n\n",
        "// Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)\n\n",
        "\n",
        f"#ifndef {guard}\n",
        f"#define {guard}\n",
        "\n",
        # The generated declarations use int32_t, so the header must include
        # <cstdint> itself rather than rely on whatever was included before it.
        "#include <cstdint>\n",
        "\n",
        "namespace kaldifeat {\n\n",
        f"constexpr int32_t kWhisperV3MelRows = {m.shape[0]};\n",
        f"constexpr int32_t kWhisperV3MelCols = {m.shape[1]};\n",
        "\n",
        "constexpr float kWhisperV3MelArray[] = {\n",
    ]

    # Row layout is kept exactly as before: a line break after every index
    # that is a non-zero multiple of 7 (so the first row holds 8 values and
    # the following rows hold 7).
    sep = ""
    for i, value in enumerate(m.reshape(-1).tolist()):
        parts.append(f"{sep}{value:.8f}")
        sep = ", "
        if i and i % 7 == 0:
            parts.append(",\n")
            sep = ""

    parts.append("};\n\n")
    parts.append("} // namespace kaldifeat\n\n")
    parts.append(f"#endif // {guard}\n")

    with open("whisper-v3-mel-bank.h", "w") as out:
        out.write("".join(parts))
|
||||
|
||||
|
||||
# Regenerate the header only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
@ -23,6 +23,7 @@
|
||||
|
||||
#include "kaldifeat/csrc/mel-computations.h"
|
||||
#include "kaldifeat/csrc/whisper-mel-bank.h"
|
||||
#include "kaldifeat/csrc/whisper-v3-mel-bank.h"
|
||||
|
||||
#ifndef M_2PI
|
||||
#define M_2PI 6.283185307179586476925286766559005
|
||||
@ -31,9 +32,18 @@
|
||||
namespace kaldifeat {
|
||||
|
||||
WhisperFbankComputer::WhisperFbankComputer(const WhisperFbankOptions &opts)
|
||||
: opts_(opts),
|
||||
mel_banks_(kWhisperMelArray, kWhisperMelRows, kWhisperMelCols,
|
||||
opts.device) {
|
||||
: opts_(opts) {
|
||||
if (opts.num_mels == 80) {
|
||||
mel_banks_ = std::make_unique<MelBanks>(kWhisperMelArray, kWhisperMelRows,
|
||||
kWhisperMelCols, opts.device);
|
||||
} else if (opts.num_mels == 128) {
|
||||
mel_banks_ = std::make_unique<MelBanks>(
|
||||
kWhisperV3MelArray, kWhisperV3MelRows, kWhisperV3MelCols, opts.device);
|
||||
} else {
|
||||
KALDIFEAT_ERR << "Unsupported num_mels: " << opts.num_mels
|
||||
<< ". Support only 80 and 128";
|
||||
}
|
||||
|
||||
opts_.frame_opts.samp_freq = 16000;
|
||||
opts_.frame_opts.frame_shift_ms = 10;
|
||||
opts_.frame_opts.frame_length_ms = 25;
|
||||
@ -67,7 +77,7 @@ torch::Tensor WhisperFbankComputer::Compute(
|
||||
torch::Tensor power = (real.square() + imag.square());
|
||||
#endif
|
||||
|
||||
torch::Tensor mel_energies = mel_banks_.Compute(power);
|
||||
torch::Tensor mel_energies = mel_banks_->Compute(power);
|
||||
torch::Tensor log_spec = torch::clamp_min(mel_energies, 1e-10).log10();
|
||||
log_spec = torch::maximum(log_spec, log_spec.max() - 8.0);
|
||||
torch::Tensor mel = (log_spec + 4.0) / 4.0;
|
||||
|
@ -19,6 +19,7 @@
|
||||
#ifndef KALDIFEAT_CSRC_WHISPER_FBANK_H_
|
||||
#define KALDIFEAT_CSRC_WHISPER_FBANK_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@ -30,12 +31,15 @@ namespace kaldifeat {
|
||||
|
||||
struct WhisperFbankOptions {
|
||||
FrameExtractionOptions frame_opts;
|
||||
// for large v3, please use 128
|
||||
int32_t num_mels = 80;
|
||||
|
||||
torch::Device device{"cpu"};
|
||||
// Returns a one-line description of these options in the form
// WhisperFbankOptions(frame_opts=..., num_mels=..., device="...").
std::string ToString() const {
  std::ostringstream os;
  os << "WhisperFbankOptions("
     << "frame_opts=" << frame_opts.ToString() << ", "
     << "num_mels=" << num_mels << ", "
     << "device=\"" << device << "\")";
  return os.str();
}
|
||||
@ -64,7 +68,7 @@ class WhisperFbankComputer {
|
||||
|
||||
private:
|
||||
WhisperFbankOptions opts_;
|
||||
MelBanks mel_banks_;
|
||||
std::unique_ptr<MelBanks> mel_banks_;
|
||||
};
|
||||
|
||||
using WhisperFbank = OfflineFeatureTpl<WhisperFbankComputer>;
|
||||
|
3693
kaldifeat/csrc/whisper-v3-mel-bank.h
Normal file
3693
kaldifeat/csrc/whisper-v3-mel-bank.h
Normal file
File diff suppressed because it is too large
Load Diff
@ -23,7 +23,7 @@ PYBIND11_MODULE(_kaldifeat, m) {
|
||||
PybindFeatureWindow(m);
|
||||
PybindMelComputations(m);
|
||||
PybindFeatureFbank(m);
|
||||
PybindWhisperFbank(m);
|
||||
PybindWhisperFbank(&m);
|
||||
PybindFeatureMfcc(m);
|
||||
PybindFeaturePlp(m);
|
||||
PybindFeatureSpectrogram(m);
|
||||
|
@ -130,6 +130,8 @@ WhisperFbankOptions WhisperFbankOptionsFromDict(py::dict dict) {
|
||||
opts.frame_opts = FrameExtractionOptionsFromDict(dict["frame_opts"]);
|
||||
}
|
||||
|
||||
FROM_DICT(int_, num_mels);
|
||||
|
||||
if (dict.contains("device")) {
|
||||
opts.device = torch::Device(std::string(py::str(dict["device"])));
|
||||
}
|
||||
@ -142,6 +144,8 @@ py::dict AsDict(const WhisperFbankOptions &opts) {
|
||||
|
||||
dict["frame_opts"] = AsDict(opts.frame_opts);
|
||||
|
||||
AS_DICT(num_mels);
|
||||
|
||||
auto torch_device = py::module_::import("torch").attr("device");
|
||||
dict["device"] = torch_device(opts.device.str());
|
||||
|
||||
|
@ -12,16 +12,18 @@
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
static void PybindWhisperFbankOptions(py::module &m) {
|
||||
static void PybindWhisperFbankOptions(py::module *m) {
|
||||
using PyClass = WhisperFbankOptions;
|
||||
py::class_<PyClass>(m, "WhisperFbankOptions")
|
||||
py::class_<PyClass>(*m, "WhisperFbankOptions")
|
||||
.def(py::init<>())
|
||||
.def(py::init([](const FrameExtractionOptions &frame_opts =
|
||||
FrameExtractionOptions(),
|
||||
int32_t num_mels = 80,
|
||||
py::object device = py::str(
|
||||
"cpu")) -> std::unique_ptr<WhisperFbankOptions> {
|
||||
auto opts = std::make_unique<WhisperFbankOptions>();
|
||||
opts->frame_opts = frame_opts;
|
||||
opts->num_mels = num_mels;
|
||||
|
||||
std::string s = static_cast<py::str>(device);
|
||||
opts->device = torch::Device(s);
|
||||
@ -29,8 +31,9 @@ static void PybindWhisperFbankOptions(py::module &m) {
|
||||
return opts;
|
||||
}),
|
||||
py::arg("frame_opts") = FrameExtractionOptions(),
|
||||
py::arg("device") = py::str("cpu"))
|
||||
py::arg("num_mels") = 80, py::arg("device") = py::str("cpu"))
|
||||
.def_readwrite("frame_opts", &PyClass::frame_opts)
|
||||
.def_readwrite("num_mels", &PyClass::num_mels)
|
||||
.def_property(
|
||||
"device",
|
||||
[](const PyClass &self) -> py::object {
|
||||
@ -56,9 +59,9 @@ static void PybindWhisperFbankOptions(py::module &m) {
|
||||
}));
|
||||
}
|
||||
|
||||
static void PybindWhisperFbankImpl(py::module &m) {
|
||||
static void PybindWhisperFbankImpl(py::module *m) {
|
||||
using PyClass = WhisperFbank;
|
||||
py::class_<PyClass>(m, "WhisperFbank")
|
||||
py::class_<PyClass>(*m, "WhisperFbank")
|
||||
.def(py::init<const WhisperFbankOptions &>(), py::arg("opts"))
|
||||
.def("dim", &PyClass::Dim)
|
||||
.def_property_readonly("options", &PyClass::GetOptions)
|
||||
@ -73,7 +76,7 @@ static void PybindWhisperFbankImpl(py::module &m) {
|
||||
}));
|
||||
}
|
||||
|
||||
void PybindWhisperFbank(py::module &m) {
|
||||
void PybindWhisperFbank(py::module *m) {
|
||||
PybindWhisperFbankOptions(m);
|
||||
PybindWhisperFbankImpl(m);
|
||||
}
|
||||
|
@ -9,7 +9,7 @@
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
void PybindWhisperFbank(py::module &m);
|
||||
void PybindWhisperFbank(py::module *m);
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
||||
|
49
kaldifeat/python/tests/test_whisper_v3_fbank.py
Normal file
49
kaldifeat/python/tests/test_whisper_v3_fbank.py
Normal file
@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
import librosa
|
||||
import torch
|
||||
import kaldifeat
|
||||
|
||||
|
||||
def get_ground_truth(x):
    """Compute reference 128-bin log-mel features with torch and librosa.

    Args:
      x: waveform tensor handed directly to ``torch.stft``.

    Returns:
      The normalized log-mel tensor, transposed so that the mel-bin axis is
      last. Shape comments below are for the 3-second input used in this file.
    """
    n_fft = 400
    hop_length = 160

    mel_filters = torch.from_numpy(
        librosa.filters.mel(sr=16000, n_fft=400, n_mels=128)
    )  # [128, 201]

    spectrum = torch.stft(
        x,
        n_fft,
        hop_length,
        window=torch.hann_window(n_fft),
        return_complex=True,
    )  # [201, 301]

    # Discard the final frame before taking the squared magnitude.
    power = spectrum[..., :-1].abs() ** 2  # [201, 300]

    mel_energy = mel_filters @ power  # [128, 300]

    log_mel = torch.clamp(mel_energy, min=1e-10).log10()
    # Limit the dynamic range to 8 (log10 units) below the global maximum.
    floor = log_mel.max() - 8.0
    log_mel = torch.maximum(log_mel, floor)
    log_mel = (log_mel + 4.0) / 4.0
    return log_mel.t()
|
||||
|
||||
|
||||
def test_whisper_v3_fbank():
    """Run WhisperFbank with 128 mel bins and check the feature dimension.

    The torch/librosa reference and the kaldifeat output are printed side by
    side for manual inspection. Their frame counts differ (see the shape
    comments), so only the mel dimension is asserted.
    """
    x = torch.rand(16000 * 3)

    gt = get_ground_truth(x)
    print(gt.shape)  # [300, 128]

    opts = kaldifeat.WhisperFbankOptions(num_mels=128, device="cpu")
    print(opts)
    # The keyword argument must actually reach the options object.
    assert opts.num_mels == 128, opts.num_mels

    whisper_fbank = kaldifeat.WhisperFbank(opts)
    y = whisper_fbank(x)  # [298, 128]
    print(y.shape)  # [298, 128]

    # Both the reference and kaldifeat must produce 128 mel bins per frame.
    assert gt.shape[-1] == 128, gt.shape
    assert y.shape[-1] == 128, y.shape

    print(gt[:5, :5])
    print(y[:5, :5])
|
||||
|
||||
|
||||
# When run as a script, seed torch first so the random input is the same
# on every run.
if __name__ == "__main__":
    torch.manual_seed(20231109)
    test_whisper_v3_fbank()
|
@ -1,6 +1,6 @@
|
||||
package:
|
||||
name: kaldifeat
|
||||
version: "1.25.2"
|
||||
version: "1.25.3"
|
||||
|
||||
source:
|
||||
path: "{{ environ.get('KALDIFEAT_ROOT_DIR') }}"
|
||||
|
@ -1,6 +1,6 @@
|
||||
package:
|
||||
name: kaldifeat
|
||||
version: "1.25.2"
|
||||
version: "1.25.3"
|
||||
|
||||
source:
|
||||
path: "{{ environ.get('KALDIFEAT_ROOT_DIR') }}"
|
||||
|
Loading…
x
Reference in New Issue
Block a user