mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-09 10:02:20 +00:00
support whisper v3 (#84)
This commit is contained in:
parent
20379449fc
commit
2624da8275
@ -7,7 +7,7 @@ project(kaldifeat)
|
|||||||
# remember to change the version in
|
# remember to change the version in
|
||||||
# scripts/conda/kaldifeat/meta.yaml
|
# scripts/conda/kaldifeat/meta.yaml
|
||||||
# scripts/conda-cpu/kaldifeat/meta.yaml
|
# scripts/conda-cpu/kaldifeat/meta.yaml
|
||||||
set(kaldifeat_VERSION "1.25.2")
|
set(kaldifeat_VERSION "1.25.3")
|
||||||
|
|
||||||
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
||||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
||||||
|
15
README.md
15
README.md
@ -38,6 +38,21 @@ See <a href="https://github.com/csukuangfj/kaldifeat/pull/82">#82</a>
|
|||||||
</td>
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
|
|
||||||
|
<tr>
|
||||||
|
<td>Fbank for <a href="https://github.com/openai/whisper">Whisper v3</a></td>
|
||||||
|
<td><code>kaldifeat.WhisperFbankOptions</code></td>
|
||||||
|
<td><code>kaldifeat.WhisperFbank</code></td>
|
||||||
|
<td>
|
||||||
|
<pre lang="python">
|
||||||
|
opts = kaldifeat.WhisperFbankOptions()
|
||||||
|
opts.num_mels = 128
|
||||||
|
opts.device = torch.device('cuda', 0)
|
||||||
|
fbank = kaldifeat.WhisperFbank(opts)
|
||||||
|
features = fbank(wave)
|
||||||
|
</pre>
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
|
||||||
<tr>
|
<tr>
|
||||||
<td>FBANK</td>
|
<td>FBANK</td>
|
||||||
<td><code>kaldifeat.FbankOptions</code></td>
|
<td><code>kaldifeat.FbankOptions</code></td>
|
||||||
|
@ -1 +1 @@
|
|||||||
exclude_files=whisper-mel-bank.h
|
exclude_files=whisper-mel-bank.h,whisper-v3-mel-bank.h
|
||||||
|
39
kaldifeat/csrc/generate-whisper-melbank-v3.py
Executable file
39
kaldifeat/csrc/generate-whisper-melbank-v3.py
Executable file
@ -0,0 +1,39 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
m = librosa.filters.mel(sr=16000, n_fft=400, n_mels=128)
|
||||||
|
assert m.shape == (128, 201)
|
||||||
|
s = "// Auto-generated. Do NOT edit!\n\n"
|
||||||
|
s += "// Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)\n\n"
|
||||||
|
s += "\n"
|
||||||
|
s += "#ifndef KALDIFEAT_CSRC_WHISPER_V3_MEL_BANK_H_\n"
|
||||||
|
s += "#define KALDIFEAT_CSRC_WHISPER_V3_MEL_BANK_H_\n"
|
||||||
|
s += "namespace kaldifeat {\n\n"
|
||||||
|
s += f"constexpr int32_t kWhisperV3MelRows = {m.shape[0]};\n"
|
||||||
|
s += f"constexpr int32_t kWhisperV3MelCols = {m.shape[1]};\n"
|
||||||
|
s += "\n"
|
||||||
|
s += "constexpr float kWhisperV3MelArray[] = {\n"
|
||||||
|
sep = ""
|
||||||
|
for i, f in enumerate(m.reshape(-1).tolist()):
|
||||||
|
s += f"{sep}{f:.8f}"
|
||||||
|
sep = ", "
|
||||||
|
if i and i % 7 == 0:
|
||||||
|
s += ",\n"
|
||||||
|
sep = ""
|
||||||
|
|
||||||
|
s += "};\n\n"
|
||||||
|
s += "} // namespace kaldifeat\n\n"
|
||||||
|
s += "#endif // KALDIFEAT_CSRC_WHISPER_V3_MEL_BANK_H_\n"
|
||||||
|
|
||||||
|
with open("whisper-v3-mel-bank.h", "w") as f:
|
||||||
|
f.write(s)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -23,6 +23,7 @@
|
|||||||
|
|
||||||
#include "kaldifeat/csrc/mel-computations.h"
|
#include "kaldifeat/csrc/mel-computations.h"
|
||||||
#include "kaldifeat/csrc/whisper-mel-bank.h"
|
#include "kaldifeat/csrc/whisper-mel-bank.h"
|
||||||
|
#include "kaldifeat/csrc/whisper-v3-mel-bank.h"
|
||||||
|
|
||||||
#ifndef M_2PI
|
#ifndef M_2PI
|
||||||
#define M_2PI 6.283185307179586476925286766559005
|
#define M_2PI 6.283185307179586476925286766559005
|
||||||
@ -31,9 +32,18 @@
|
|||||||
namespace kaldifeat {
|
namespace kaldifeat {
|
||||||
|
|
||||||
WhisperFbankComputer::WhisperFbankComputer(const WhisperFbankOptions &opts)
|
WhisperFbankComputer::WhisperFbankComputer(const WhisperFbankOptions &opts)
|
||||||
: opts_(opts),
|
: opts_(opts) {
|
||||||
mel_banks_(kWhisperMelArray, kWhisperMelRows, kWhisperMelCols,
|
if (opts.num_mels == 80) {
|
||||||
opts.device) {
|
mel_banks_ = std::make_unique<MelBanks>(kWhisperMelArray, kWhisperMelRows,
|
||||||
|
kWhisperMelCols, opts.device);
|
||||||
|
} else if (opts.num_mels == 128) {
|
||||||
|
mel_banks_ = std::make_unique<MelBanks>(
|
||||||
|
kWhisperV3MelArray, kWhisperV3MelRows, kWhisperV3MelCols, opts.device);
|
||||||
|
} else {
|
||||||
|
KALDIFEAT_ERR << "Unsupported num_mels: " << opts.num_mels
|
||||||
|
<< ". Support only 80 and 128";
|
||||||
|
}
|
||||||
|
|
||||||
opts_.frame_opts.samp_freq = 16000;
|
opts_.frame_opts.samp_freq = 16000;
|
||||||
opts_.frame_opts.frame_shift_ms = 10;
|
opts_.frame_opts.frame_shift_ms = 10;
|
||||||
opts_.frame_opts.frame_length_ms = 25;
|
opts_.frame_opts.frame_length_ms = 25;
|
||||||
@ -67,7 +77,7 @@ torch::Tensor WhisperFbankComputer::Compute(
|
|||||||
torch::Tensor power = (real.square() + imag.square());
|
torch::Tensor power = (real.square() + imag.square());
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
torch::Tensor mel_energies = mel_banks_.Compute(power);
|
torch::Tensor mel_energies = mel_banks_->Compute(power);
|
||||||
torch::Tensor log_spec = torch::clamp_min(mel_energies, 1e-10).log10();
|
torch::Tensor log_spec = torch::clamp_min(mel_energies, 1e-10).log10();
|
||||||
log_spec = torch::maximum(log_spec, log_spec.max() - 8.0);
|
log_spec = torch::maximum(log_spec, log_spec.max() - 8.0);
|
||||||
torch::Tensor mel = (log_spec + 4.0) / 4.0;
|
torch::Tensor mel = (log_spec + 4.0) / 4.0;
|
||||||
|
@ -19,6 +19,7 @@
|
|||||||
#ifndef KALDIFEAT_CSRC_WHISPER_FBANK_H_
|
#ifndef KALDIFEAT_CSRC_WHISPER_FBANK_H_
|
||||||
#define KALDIFEAT_CSRC_WHISPER_FBANK_H_
|
#define KALDIFEAT_CSRC_WHISPER_FBANK_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -30,12 +31,15 @@ namespace kaldifeat {
|
|||||||
|
|
||||||
struct WhisperFbankOptions {
|
struct WhisperFbankOptions {
|
||||||
FrameExtractionOptions frame_opts;
|
FrameExtractionOptions frame_opts;
|
||||||
|
// for large v3, please use 128
|
||||||
|
int32_t num_mels = 80;
|
||||||
|
|
||||||
torch::Device device{"cpu"};
|
torch::Device device{"cpu"};
|
||||||
std::string ToString() const {
|
std::string ToString() const {
|
||||||
std::ostringstream os;
|
std::ostringstream os;
|
||||||
os << "WhisperFbankOptions(";
|
os << "WhisperFbankOptions(";
|
||||||
os << "frame_opts=" << frame_opts.ToString() << ", ";
|
os << "frame_opts=" << frame_opts.ToString() << ", ";
|
||||||
|
os << "num_mels=" << num_mels << ", ";
|
||||||
os << "device=\"" << device << "\")";
|
os << "device=\"" << device << "\")";
|
||||||
return os.str();
|
return os.str();
|
||||||
}
|
}
|
||||||
@ -64,7 +68,7 @@ class WhisperFbankComputer {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
WhisperFbankOptions opts_;
|
WhisperFbankOptions opts_;
|
||||||
MelBanks mel_banks_;
|
std::unique_ptr<MelBanks> mel_banks_;
|
||||||
};
|
};
|
||||||
|
|
||||||
using WhisperFbank = OfflineFeatureTpl<WhisperFbankComputer>;
|
using WhisperFbank = OfflineFeatureTpl<WhisperFbankComputer>;
|
||||||
|
3693
kaldifeat/csrc/whisper-v3-mel-bank.h
Normal file
3693
kaldifeat/csrc/whisper-v3-mel-bank.h
Normal file
File diff suppressed because it is too large
Load Diff
@ -23,7 +23,7 @@ PYBIND11_MODULE(_kaldifeat, m) {
|
|||||||
PybindFeatureWindow(m);
|
PybindFeatureWindow(m);
|
||||||
PybindMelComputations(m);
|
PybindMelComputations(m);
|
||||||
PybindFeatureFbank(m);
|
PybindFeatureFbank(m);
|
||||||
PybindWhisperFbank(m);
|
PybindWhisperFbank(&m);
|
||||||
PybindFeatureMfcc(m);
|
PybindFeatureMfcc(m);
|
||||||
PybindFeaturePlp(m);
|
PybindFeaturePlp(m);
|
||||||
PybindFeatureSpectrogram(m);
|
PybindFeatureSpectrogram(m);
|
||||||
|
@ -130,6 +130,8 @@ WhisperFbankOptions WhisperFbankOptionsFromDict(py::dict dict) {
|
|||||||
opts.frame_opts = FrameExtractionOptionsFromDict(dict["frame_opts"]);
|
opts.frame_opts = FrameExtractionOptionsFromDict(dict["frame_opts"]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FROM_DICT(int_, num_mels);
|
||||||
|
|
||||||
if (dict.contains("device")) {
|
if (dict.contains("device")) {
|
||||||
opts.device = torch::Device(std::string(py::str(dict["device"])));
|
opts.device = torch::Device(std::string(py::str(dict["device"])));
|
||||||
}
|
}
|
||||||
@ -142,6 +144,8 @@ py::dict AsDict(const WhisperFbankOptions &opts) {
|
|||||||
|
|
||||||
dict["frame_opts"] = AsDict(opts.frame_opts);
|
dict["frame_opts"] = AsDict(opts.frame_opts);
|
||||||
|
|
||||||
|
AS_DICT(num_mels);
|
||||||
|
|
||||||
auto torch_device = py::module_::import("torch").attr("device");
|
auto torch_device = py::module_::import("torch").attr("device");
|
||||||
dict["device"] = torch_device(opts.device.str());
|
dict["device"] = torch_device(opts.device.str());
|
||||||
|
|
||||||
|
@ -12,16 +12,18 @@
|
|||||||
|
|
||||||
namespace kaldifeat {
|
namespace kaldifeat {
|
||||||
|
|
||||||
static void PybindWhisperFbankOptions(py::module &m) {
|
static void PybindWhisperFbankOptions(py::module *m) {
|
||||||
using PyClass = WhisperFbankOptions;
|
using PyClass = WhisperFbankOptions;
|
||||||
py::class_<PyClass>(m, "WhisperFbankOptions")
|
py::class_<PyClass>(*m, "WhisperFbankOptions")
|
||||||
.def(py::init<>())
|
.def(py::init<>())
|
||||||
.def(py::init([](const FrameExtractionOptions &frame_opts =
|
.def(py::init([](const FrameExtractionOptions &frame_opts =
|
||||||
FrameExtractionOptions(),
|
FrameExtractionOptions(),
|
||||||
|
int32_t num_mels = 80,
|
||||||
py::object device = py::str(
|
py::object device = py::str(
|
||||||
"cpu")) -> std::unique_ptr<WhisperFbankOptions> {
|
"cpu")) -> std::unique_ptr<WhisperFbankOptions> {
|
||||||
auto opts = std::make_unique<WhisperFbankOptions>();
|
auto opts = std::make_unique<WhisperFbankOptions>();
|
||||||
opts->frame_opts = frame_opts;
|
opts->frame_opts = frame_opts;
|
||||||
|
opts->num_mels = num_mels;
|
||||||
|
|
||||||
std::string s = static_cast<py::str>(device);
|
std::string s = static_cast<py::str>(device);
|
||||||
opts->device = torch::Device(s);
|
opts->device = torch::Device(s);
|
||||||
@ -29,8 +31,9 @@ static void PybindWhisperFbankOptions(py::module &m) {
|
|||||||
return opts;
|
return opts;
|
||||||
}),
|
}),
|
||||||
py::arg("frame_opts") = FrameExtractionOptions(),
|
py::arg("frame_opts") = FrameExtractionOptions(),
|
||||||
py::arg("device") = py::str("cpu"))
|
py::arg("num_mels") = 80, py::arg("device") = py::str("cpu"))
|
||||||
.def_readwrite("frame_opts", &PyClass::frame_opts)
|
.def_readwrite("frame_opts", &PyClass::frame_opts)
|
||||||
|
.def_readwrite("num_mels", &PyClass::num_mels)
|
||||||
.def_property(
|
.def_property(
|
||||||
"device",
|
"device",
|
||||||
[](const PyClass &self) -> py::object {
|
[](const PyClass &self) -> py::object {
|
||||||
@ -56,9 +59,9 @@ static void PybindWhisperFbankOptions(py::module &m) {
|
|||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
static void PybindWhisperFbankImpl(py::module &m) {
|
static void PybindWhisperFbankImpl(py::module *m) {
|
||||||
using PyClass = WhisperFbank;
|
using PyClass = WhisperFbank;
|
||||||
py::class_<PyClass>(m, "WhisperFbank")
|
py::class_<PyClass>(*m, "WhisperFbank")
|
||||||
.def(py::init<const WhisperFbankOptions &>(), py::arg("opts"))
|
.def(py::init<const WhisperFbankOptions &>(), py::arg("opts"))
|
||||||
.def("dim", &PyClass::Dim)
|
.def("dim", &PyClass::Dim)
|
||||||
.def_property_readonly("options", &PyClass::GetOptions)
|
.def_property_readonly("options", &PyClass::GetOptions)
|
||||||
@ -73,7 +76,7 @@ static void PybindWhisperFbankImpl(py::module &m) {
|
|||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
void PybindWhisperFbank(py::module &m) {
|
void PybindWhisperFbank(py::module *m) {
|
||||||
PybindWhisperFbankOptions(m);
|
PybindWhisperFbankOptions(m);
|
||||||
PybindWhisperFbankImpl(m);
|
PybindWhisperFbankImpl(m);
|
||||||
}
|
}
|
||||||
|
@ -9,7 +9,7 @@
|
|||||||
|
|
||||||
namespace kaldifeat {
|
namespace kaldifeat {
|
||||||
|
|
||||||
void PybindWhisperFbank(py::module &m);
|
void PybindWhisperFbank(py::module *m);
|
||||||
|
|
||||||
} // namespace kaldifeat
|
} // namespace kaldifeat
|
||||||
|
|
||||||
|
49
kaldifeat/python/tests/test_whisper_v3_fbank.py
Normal file
49
kaldifeat/python/tests/test_whisper_v3_fbank.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2023 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import torch
|
||||||
|
import kaldifeat
|
||||||
|
|
||||||
|
|
||||||
|
def get_ground_truth(x):
|
||||||
|
N_FFT = 400
|
||||||
|
HOP_LENGTH = 160
|
||||||
|
|
||||||
|
m = librosa.filters.mel(sr=16000, n_fft=400, n_mels=128)
|
||||||
|
m = torch.from_numpy(m)
|
||||||
|
# print(m.shape) # [128, 201]
|
||||||
|
window = torch.hann_window(N_FFT)
|
||||||
|
stft = torch.stft(x, N_FFT, HOP_LENGTH, window=window, return_complex=True)
|
||||||
|
# print(stft.shape) # [201, 301]
|
||||||
|
magnitudes = stft[..., :-1].abs() ** 2
|
||||||
|
# print(magnitudes.shape) # [201, 300]
|
||||||
|
|
||||||
|
mel_spec = m @ magnitudes
|
||||||
|
# print(mel_spec.shape) # [128, 300]
|
||||||
|
|
||||||
|
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
||||||
|
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
||||||
|
log_spec = (log_spec + 4.0) / 4.0
|
||||||
|
return log_spec.t()
|
||||||
|
|
||||||
|
|
||||||
|
def test_whisper_v3_fbank():
|
||||||
|
x = torch.rand(16000 * 3)
|
||||||
|
|
||||||
|
gt = get_ground_truth(x)
|
||||||
|
print(gt.shape) # [300, 128]
|
||||||
|
|
||||||
|
opts = kaldifeat.WhisperFbankOptions(num_mels=128, device="cpu")
|
||||||
|
print(opts)
|
||||||
|
whisper_fbank = kaldifeat.WhisperFbank(opts)
|
||||||
|
y = whisper_fbank(x) # [298, 128]
|
||||||
|
print(y.shape) # [298, 128]
|
||||||
|
|
||||||
|
print(gt[:5, :5])
|
||||||
|
print(y[:5, :5])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
torch.manual_seed(20231109)
|
||||||
|
test_whisper_v3_fbank()
|
@ -1,6 +1,6 @@
|
|||||||
package:
|
package:
|
||||||
name: kaldifeat
|
name: kaldifeat
|
||||||
version: "1.25.2"
|
version: "1.25.3"
|
||||||
|
|
||||||
source:
|
source:
|
||||||
path: "{{ environ.get('KALDIFEAT_ROOT_DIR') }}"
|
path: "{{ environ.get('KALDIFEAT_ROOT_DIR') }}"
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
package:
|
package:
|
||||||
name: kaldifeat
|
name: kaldifeat
|
||||||
version: "1.25.2"
|
version: "1.25.3"
|
||||||
|
|
||||||
source:
|
source:
|
||||||
path: "{{ environ.get('KALDIFEAT_ROOT_DIR') }}"
|
path: "{{ environ.get('KALDIFEAT_ROOT_DIR') }}"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user