mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-09 10:02:20 +00:00
Support computing features for whisper (#82)
This commit is contained in:
parent
7912c2f442
commit
01aed93b1b
@ -7,7 +7,7 @@ project(kaldifeat)
|
|||||||
# remember to change the version in
|
# remember to change the version in
|
||||||
# scripts/conda/kaldifeat/meta.yaml
|
# scripts/conda/kaldifeat/meta.yaml
|
||||||
# scripts/conda-cpu/kaldifeat/meta.yaml
|
# scripts/conda-cpu/kaldifeat/meta.yaml
|
||||||
set(kaldifeat_VERSION "1.25.1")
|
set(kaldifeat_VERSION "1.25.2")
|
||||||
|
|
||||||
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
||||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
||||||
|
@ -10,6 +10,7 @@ set(kaldifeat_srcs
|
|||||||
matrix-functions.cc
|
matrix-functions.cc
|
||||||
mel-computations.cc
|
mel-computations.cc
|
||||||
online-feature.cc
|
online-feature.cc
|
||||||
|
whisper-fbank.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
add_library(kaldifeat_core ${kaldifeat_srcs})
|
add_library(kaldifeat_core ${kaldifeat_srcs})
|
||||||
|
1
kaldifeat/csrc/CPPLINT.cfg
Normal file
1
kaldifeat/csrc/CPPLINT.cfg
Normal file
@ -0,0 +1 @@
|
|||||||
|
exclude_files=whisper-mel-bank.h
|
@ -65,7 +65,7 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy,
|
|||||||
// note spectrum is in magnitude, not power, because of `abs()`
|
// note spectrum is in magnitude, not power, because of `abs()`
|
||||||
#if defined(KALDIFEAT_HAS_FFT_NAMESPACE)
|
#if defined(KALDIFEAT_HAS_FFT_NAMESPACE)
|
||||||
// signal_frame shape: [x, 512]
|
// signal_frame shape: [x, 512]
|
||||||
// spectrum shape [x, 257
|
// spectrum shape [x, 257]
|
||||||
torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
|
torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
|
||||||
#else
|
#else
|
||||||
// signal_frame shape [x, 512]
|
// signal_frame shape [x, 512]
|
||||||
|
@ -29,6 +29,13 @@ FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts,
|
|||||||
float *window_data = window.data_ptr<float>();
|
float *window_data = window.data_ptr<float>();
|
||||||
|
|
||||||
double a = M_2PI / (frame_length - 1);
|
double a = M_2PI / (frame_length - 1);
|
||||||
|
|
||||||
|
if (opts.window_type == "hann") {
|
||||||
|
// see https://pytorch.org/docs/stable/generated/torch.hann_window.html
|
||||||
|
// We assume periodic is true
|
||||||
|
a = M_2PI / frame_length;
|
||||||
|
}
|
||||||
|
|
||||||
for (int32_t i = 0; i < frame_length; i++) {
|
for (int32_t i = 0; i < frame_length; i++) {
|
||||||
double i_fl = static_cast<double>(i);
|
double i_fl = static_cast<double>(i);
|
||||||
if (opts.window_type == "hanning") {
|
if (opts.window_type == "hanning") {
|
||||||
@ -39,6 +46,8 @@ FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts,
|
|||||||
window_data[i] = sin(0.5 * a * i_fl);
|
window_data[i] = sin(0.5 * a * i_fl);
|
||||||
} else if (opts.window_type == "hamming") {
|
} else if (opts.window_type == "hamming") {
|
||||||
window_data[i] = 0.54 - 0.46 * cos(a * i_fl);
|
window_data[i] = 0.54 - 0.46 * cos(a * i_fl);
|
||||||
|
} else if (opts.window_type == "hann") {
|
||||||
|
window_data[i] = 0.50 - 0.50 * cos(a * i_fl);
|
||||||
} else if (opts.window_type ==
|
} else if (opts.window_type ==
|
||||||
"povey") { // like hamming but goes to zero at edges.
|
"povey") { // like hamming but goes to zero at edges.
|
||||||
window_data[i] = pow(0.5 - 0.5 * cos(a * i_fl), 0.85);
|
window_data[i] = pow(0.5 - 0.5 * cos(a * i_fl), 0.85);
|
||||||
|
39
kaldifeat/csrc/generate-whisper-melbank.py
Executable file
39
kaldifeat/csrc/generate-whisper-melbank.py
Executable file
@ -0,0 +1,39 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Write ``whisper-mel-bank.h`` into the current working directory.

    The header contains the 80 x 201 mel filter bank returned by
    ``librosa.filters.mel(sr=16000, n_fft=400, n_mels=80)``, flattened in
    row-major order into a ``constexpr float`` array, together with the
    number of rows and columns of the matrix.
    """
    m = librosa.filters.mel(sr=16000, n_fft=400, n_mels=80)
    assert m.shape == (80, 201)

    s = "// Auto-generated. Do NOT edit!\n\n"
    s += "// Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)\n\n"
    s += "\n"
    s += "#ifndef KALDIFEAT_CSRC_WHISPER_MEL_BANK_H_\n"
    s += "#define KALDIFEAT_CSRC_WHISPER_MEL_BANK_H_\n"
    # The generated declarations use int32_t, so include its header to keep
    # the generated file self-contained.
    s += "\n#include <cstdint>\n\n"
    s += "namespace kaldifeat {\n\n"
    s += f"constexpr int32_t kWhisperMelRows = {m.shape[0]};\n"
    s += f"constexpr int32_t kWhisperMelCols = {m.shape[1]};\n"
    s += "\n"
    s += "constexpr float kWhisperMelArray[] = {\n"

    # Values are separated by ", ". A line break is emitted after every
    # element whose index is a non-zero multiple of 7, so the first line
    # holds 8 values and each following line holds 7.
    sep = ""
    for i, f in enumerate(m.reshape(-1).tolist()):
        s += f"{sep}{f:.8f}"
        sep = ", "
        if i and i % 7 == 0:
            s += ",\n"
            sep = ""

    s += "};\n\n"
    s += "} // namespace kaldifeat\n\n"
    s += "#endif // KALDIFEAT_CSRC_WHISPER_MEL_BANK_H_\n"

    with open("whisper-mel-bank.h", "w") as f:
        f.write(s)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -196,6 +196,15 @@ MelBanks::MelBanks(const MelBanksOptions &opts,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Builds the filter bank from a precomputed weights matrix.
//
// `weights` points to num_rows * num_cols floats stored in row-major order.
// The matrix is stored transposed, i.e. bins_mat_ has shape
// [num_cols, num_rows], so that Compute() can multiply a spectrum by it
// directly. center_freqs_ is left uninitialized by this constructor.
MelBanks::MelBanks(const float *weights, int32_t num_rows, int32_t num_cols,
                   torch::Device device)
    : debug_(false), htk_mode_(false) {
  // torch::from_blob() only wraps the caller's buffer and .to(device) does
  // not copy when the tensor is already on `device`. clone() gives bins_mat_
  // its own storage so that it neither depends on the lifetime of `weights`
  // nor aliases memory obtained through const_cast.
  bins_mat_ = torch::from_blob(const_cast<float *>(weights),
                               {num_rows, num_cols}, torch::kFloat)
                  .clone()
                  .t()
                  .to(device);
}
|
||||||
|
|
||||||
torch::Tensor MelBanks::Compute(const torch::Tensor &spectrum) const {
|
torch::Tensor MelBanks::Compute(const torch::Tensor &spectrum) const {
|
||||||
return torch::mm(spectrum, bins_mat_);
|
return torch::mm(spectrum, bins_mat_);
|
||||||
}
|
}
|
||||||
|
@ -76,6 +76,17 @@ class MelBanks {
|
|||||||
const FrameExtractionOptions &frame_opts, float vtln_warp_factor,
|
const FrameExtractionOptions &frame_opts, float vtln_warp_factor,
|
||||||
torch::Device device);
|
torch::Device device);
|
||||||
|
|
||||||
|
// Initialize with a 2-d weights matrix
|
||||||
|
//
|
||||||
|
// Note: This constructor is for Whisper. It does not initialize
|
||||||
|
// center_freqs_.
|
||||||
|
//
|
||||||
|
// @param weights Pointer to the start address of the matrix
|
||||||
|
// @param num_rows Number of rows of the matrix, i.e., the number of mel bins
|
||||||
|
// @param num_cols Number of columns of the matrix, i.e., (FFT size)/2+1
|
||||||
|
MelBanks(const float *weights, int32_t num_rows, int32_t num_cols,
|
||||||
|
torch::Device device);
|
||||||
|
|
||||||
// CAUTION: we save a transposed version of bins_mat_, so return size(1) here
|
// CAUTION: we save a transposed version of bins_mat_, so return size(1) here
|
||||||
int32_t NumBins() const { return static_cast<int32_t>(bins_mat_.size(1)); }
|
int32_t NumBins() const { return static_cast<int32_t>(bins_mat_.size(1)); }
|
||||||
|
|
||||||
@ -89,7 +100,8 @@ class MelBanks {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
// A 2-D matrix. Its shape is NOT [num_bins, num_fft_bins]
|
// A 2-D matrix. Its shape is NOT [num_bins, num_fft_bins]
|
||||||
// Its shape is [num_fft_bins, num_bins].
|
// Its shape is [num_fft_bins, num_bins] for non-whisper.
|
||||||
|
// For whisper, its shape is [num_fft_bins/2+1, num_bins]
|
||||||
torch::Tensor bins_mat_;
|
torch::Tensor bins_mat_;
|
||||||
|
|
||||||
// center frequencies of bins, numbered from 0 ... num_bins-1.
|
// center frequencies of bins, numbered from 0 ... num_bins-1.
|
||||||
|
78
kaldifeat/csrc/whisper-fbank.cc
Normal file
78
kaldifeat/csrc/whisper-fbank.cc
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
/**
|
||||||
|
* Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
*
|
||||||
|
* See LICENSE for clarification regarding multiple authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "kaldifeat/csrc/whisper-fbank.h"
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "kaldifeat/csrc/mel-computations.h"
|
||||||
|
#include "kaldifeat/csrc/whisper-mel-bank.h"
|
||||||
|
|
||||||
|
#ifndef M_2PI
|
||||||
|
#define M_2PI 6.283185307179586476925286766559005
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace kaldifeat {
|
||||||
|
|
||||||
|
// Creates the computer from the built-in Whisper mel filter bank, placed on
// opts.device. The frame extraction settings listed below are fixed for
// Whisper, so whatever the caller put into the corresponding fields of
// opts.frame_opts is overwritten.
WhisperFbankComputer::WhisperFbankComputer(const WhisperFbankOptions &opts)
    : opts_(opts),
      mel_banks_(kWhisperMelArray, kWhisperMelRows, kWhisperMelCols,
                 opts.device) {
  auto &frame_opts = opts_.frame_opts;

  // Framing: 25 ms frames with a 10 ms shift at 16 kHz, no rounding of the
  // window size to a power of two.
  frame_opts.samp_freq = 16000;
  frame_opts.frame_length_ms = 25;
  frame_opts.frame_shift_ms = 10;
  frame_opts.round_to_power_of_two = false;
  frame_opts.snip_edges = false;

  // Windowing
  frame_opts.window_type = "hann";

  // No dithering, pre-emphasis or DC offset removal
  frame_opts.dither = 0;
  frame_opts.preemph_coeff = 0;
  frame_opts.remove_dc_offset = false;
}
|
||||||
|
|
||||||
|
// Computes Whisper-style log-mel features for a batch of windowed frames.
//
// The first two parameters are unused; they are present in the signature but
// left unnamed.
//
// @param signal_frame  A 2-D tensor of shape [x, PaddedWindowSize()]
// @return A 2-D tensor of shape [x, number of mel bins]
torch::Tensor WhisperFbankComputer::Compute(
    torch::Tensor /*signal_raw_log_energy*/, float /*vtln_warp*/,
    const torch::Tensor &signal_frame) {
  KALDIFEAT_ASSERT(signal_frame.dim() == 2);
  KALDIFEAT_ASSERT(signal_frame.size(1) == opts_.frame_opts.PaddedWindowSize());

  // note: we use the power spectrum here, i.e., the squared magnitude
#if defined(KALDIFEAT_HAS_FFT_NAMESPACE)
  // signal_frame shape: [x, 400]
  // power shape [x, 201]
  torch::Tensor power = torch::fft::rfft(signal_frame).abs().pow(2);
#else
  // signal_frame shape [x, 400]
  // real_imag shape [x, 201, 2],
  //   where [..., 0] is the real part
  //         [..., 1] is the imaginary part
  torch::Tensor real_imag = torch::rfft(signal_frame, 1);
  torch::Tensor real = real_imag.index({"...", 0});
  torch::Tensor imag = real_imag.index({"...", 1});
  torch::Tensor power = (real.square() + imag.square());
#endif

  torch::Tensor mel_energies = mel_banks_.Compute(power);

  // Clamp before taking the logarithm so that log10() never sees 0
  torch::Tensor log_spec = torch::clamp_min(mel_energies, 1e-10).log10();

  // Limit the dynamic range to 8 (in log10 units) below the maximum.
  // Note that the maximum is taken over all elements passed to this call.
  log_spec = torch::maximum(log_spec, log_spec.max() - 8.0);

  torch::Tensor mel = (log_spec + 4.0) / 4.0;

  return mel;
}
|
||||||
|
|
||||||
|
} // namespace kaldifeat
|
74
kaldifeat/csrc/whisper-fbank.h
Normal file
74
kaldifeat/csrc/whisper-fbank.h
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
/**
|
||||||
|
* Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
*
|
||||||
|
* See LICENSE for clarification regarding multiple authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef KALDIFEAT_CSRC_WHISPER_FBANK_H_
|
||||||
|
#define KALDIFEAT_CSRC_WHISPER_FBANK_H_
|
||||||
|
|
||||||
|
#include <sstream>
#include <string>
#include <vector>

#include "kaldifeat/csrc/feature-common.h"
#include "kaldifeat/csrc/feature-window.h"
#include "kaldifeat/csrc/mel-computations.h"
|
||||||
|
|
||||||
|
namespace kaldifeat {
|
||||||
|
|
||||||
|
struct WhisperFbankOptions {
|
||||||
|
FrameExtractionOptions frame_opts;
|
||||||
|
|
||||||
|
torch::Device device{"cpu"};
|
||||||
|
std::string ToString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "WhisperFbankOptions(";
|
||||||
|
os << "frame_opts=" << frame_opts.ToString() << ", ";
|
||||||
|
os << "device=\"" << device << "\")";
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class WhisperFbankComputer {
|
||||||
|
public:
|
||||||
|
// note: Only frame_opts.device is used. All other fields from frame_opts
|
||||||
|
// are ignored
|
||||||
|
explicit WhisperFbankComputer(const WhisperFbankOptions &opts = {});
|
||||||
|
|
||||||
|
int32_t Dim() const { return 80; }
|
||||||
|
|
||||||
|
const FrameExtractionOptions &GetFrameOptions() const {
|
||||||
|
return opts_.frame_opts;
|
||||||
|
}
|
||||||
|
|
||||||
|
const WhisperFbankOptions &GetOptions() const { return opts_; }
|
||||||
|
|
||||||
|
torch::Tensor Compute(torch::Tensor /*signal_raw_log_energy*/,
|
||||||
|
float /*vtln_warp*/, const torch::Tensor &signal_frame);
|
||||||
|
|
||||||
|
// if true, compute log_energy_pre_window but after dithering and dc removal
|
||||||
|
bool NeedRawLogEnergy() const { return false; }
|
||||||
|
using Options = WhisperFbankOptions;
|
||||||
|
|
||||||
|
private:
|
||||||
|
WhisperFbankOptions opts_;
|
||||||
|
MelBanks mel_banks_;
|
||||||
|
};
|
||||||
|
|
||||||
|
using WhisperFbank = OfflineFeatureTpl<WhisperFbankComputer>;
|
||||||
|
|
||||||
|
} // namespace kaldifeat
|
||||||
|
|
||||||
|
#endif // KALDIFEAT_CSRC_WHISPER_FBANK_H_
|
2315
kaldifeat/csrc/whisper-mel-bank.h
Normal file
2315
kaldifeat/csrc/whisper-mel-bank.h
Normal file
File diff suppressed because it is too large
Load Diff
@ -9,6 +9,7 @@ pybind11_add_module(_kaldifeat
|
|||||||
mel-computations.cc
|
mel-computations.cc
|
||||||
online-feature.cc
|
online-feature.cc
|
||||||
utils.cc
|
utils.cc
|
||||||
|
whisper-fbank.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
if(APPLE)
|
if(APPLE)
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
#include "kaldifeat/python/csrc/feature-window.h"
|
#include "kaldifeat/python/csrc/feature-window.h"
|
||||||
#include "kaldifeat/python/csrc/mel-computations.h"
|
#include "kaldifeat/python/csrc/mel-computations.h"
|
||||||
#include "kaldifeat/python/csrc/online-feature.h"
|
#include "kaldifeat/python/csrc/online-feature.h"
|
||||||
|
#include "kaldifeat/python/csrc/whisper-fbank.h"
|
||||||
#include "torch/torch.h"
|
#include "torch/torch.h"
|
||||||
|
|
||||||
namespace kaldifeat {
|
namespace kaldifeat {
|
||||||
@ -22,6 +23,7 @@ PYBIND11_MODULE(_kaldifeat, m) {
|
|||||||
PybindFeatureWindow(m);
|
PybindFeatureWindow(m);
|
||||||
PybindMelComputations(m);
|
PybindMelComputations(m);
|
||||||
PybindFeatureFbank(m);
|
PybindFeatureFbank(m);
|
||||||
|
PybindWhisperFbank(m);
|
||||||
PybindFeatureMfcc(m);
|
PybindFeatureMfcc(m);
|
||||||
PybindFeaturePlp(m);
|
PybindFeaturePlp(m);
|
||||||
PybindFeatureSpectrogram(m);
|
PybindFeatureSpectrogram(m);
|
||||||
|
@ -123,6 +123,31 @@ py::dict AsDict(const FbankOptions &opts) {
|
|||||||
return dict;
|
return dict;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Creates a WhisperFbankOptions from a Python dict.
//
// Recognized keys are "frame_opts" and "device". A key that is missing from
// the dict leaves the corresponding field at its default value.
WhisperFbankOptions WhisperFbankOptionsFromDict(py::dict dict) {
  WhisperFbankOptions ans;

  if (dict.contains("frame_opts")) {
    ans.frame_opts = FrameExtractionOptionsFromDict(dict["frame_opts"]);
  }

  if (dict.contains("device")) {
    // The value is converted with Python's str() before it is parsed
    const std::string device_str = std::string(py::str(dict["device"]));
    ans.device = torch::Device(device_str);
  }

  return ans;
}
|
||||||
|
|
||||||
|
// Converts a WhisperFbankOptions to a Python dict with the keys
// "frame_opts" (itself a dict) and "device" (a torch.device object).
py::dict AsDict(const WhisperFbankOptions &opts) {
  py::dict ans;

  ans["frame_opts"] = AsDict(opts.frame_opts);

  // Build a torch.device from the string representation of the C++ device
  py::object make_device = py::module_::import("torch").attr("device");
  ans["device"] = make_device(opts.device.str());

  return ans;
}
|
||||||
|
|
||||||
MfccOptions MfccOptionsFromDict(py::dict dict) {
|
MfccOptions MfccOptionsFromDict(py::dict dict) {
|
||||||
MfccOptions opts;
|
MfccOptions opts;
|
||||||
|
|
||||||
|
@ -11,6 +11,7 @@
|
|||||||
#include "kaldifeat/csrc/feature-spectrogram.h"
|
#include "kaldifeat/csrc/feature-spectrogram.h"
|
||||||
#include "kaldifeat/csrc/feature-window.h"
|
#include "kaldifeat/csrc/feature-window.h"
|
||||||
#include "kaldifeat/csrc/mel-computations.h"
|
#include "kaldifeat/csrc/mel-computations.h"
|
||||||
|
#include "kaldifeat/csrc/whisper-fbank.h"
|
||||||
#include "kaldifeat/python/csrc/kaldifeat.h"
|
#include "kaldifeat/python/csrc/kaldifeat.h"
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -36,6 +37,9 @@ py::dict AsDict(const MelBanksOptions &opts);
|
|||||||
FbankOptions FbankOptionsFromDict(py::dict dict);
|
FbankOptions FbankOptionsFromDict(py::dict dict);
|
||||||
py::dict AsDict(const FbankOptions &opts);
|
py::dict AsDict(const FbankOptions &opts);
|
||||||
|
|
||||||
|
WhisperFbankOptions WhisperFbankOptionsFromDict(py::dict dict);
|
||||||
|
py::dict AsDict(const WhisperFbankOptions &opts);
|
||||||
|
|
||||||
MfccOptions MfccOptionsFromDict(py::dict dict);
|
MfccOptions MfccOptionsFromDict(py::dict dict);
|
||||||
py::dict AsDict(const MfccOptions &opts);
|
py::dict AsDict(const MfccOptions &opts);
|
||||||
|
|
||||||
|
81
kaldifeat/python/csrc/whisper-fbank.cc
Normal file
81
kaldifeat/python/csrc/whisper-fbank.cc
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
// kaldifeat/python/csrc/whisper-fbank.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
#include "kaldifeat/python/csrc/whisper-fbank.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "kaldifeat/csrc/whisper-fbank.h"
|
||||||
|
#include "kaldifeat/python/csrc/utils.h"
|
||||||
|
|
||||||
|
namespace kaldifeat {
|
||||||
|
|
||||||
|
// Registers WhisperFbankOptions with the Python module `m`.
//
// Besides the constructors it exposes the fields `frame_opts` and `device`,
// `__str__`, `as_dict()`, the static method `from_dict()` and pickling
// support (the pickled state is the dict returned by AsDict()).
static void PybindWhisperFbankOptions(py::module &m) {
  using PyClass = WhisperFbankOptions;

  py::class_<PyClass> cls(m, "WhisperFbankOptions");

  cls.def(py::init<>());

  // The default values of the arguments are provided by py::arg() below.
  cls.def(py::init([](const FrameExtractionOptions &frame_opts,
                      py::object device) -> std::unique_ptr<PyClass> {
            auto ans = std::make_unique<PyClass>();
            ans->frame_opts = frame_opts;

            // `device` is converted with Python's str() before parsing
            std::string device_str = static_cast<py::str>(device);
            ans->device = torch::Device(device_str);

            return ans;
          }),
          py::arg("frame_opts") = FrameExtractionOptions(),
          py::arg("device") = py::str("cpu"));

  cls.def_readwrite("frame_opts", &PyClass::frame_opts);

  // `device` is returned as a torch.device object and can be assigned from
  // any object whose str() is accepted by torch::Device.
  cls.def_property(
      "device",
      [](const PyClass &self) -> py::object {
        py::object make_device = py::module_::import("torch").attr("device");
        return make_device(self.device.str());
      },
      [](PyClass &self, py::object obj) -> void {
        std::string device_str = static_cast<py::str>(obj);
        self.device = torch::Device(device_str);
      });

  cls.def("__str__",
          [](const PyClass &self) -> std::string { return self.ToString(); });

  cls.def("as_dict",
          [](const PyClass &self) -> py::dict { return AsDict(self); });

  cls.def_static("from_dict", [](py::dict dict) -> PyClass {
    return WhisperFbankOptionsFromDict(dict);
  });

  cls.def(py::pickle(
      [](const PyClass &self) -> py::dict { return AsDict(self); },
      [](py::dict dict) -> PyClass {
        return WhisperFbankOptionsFromDict(dict);
      }));
}
|
||||||
|
|
||||||
|
// Registers WhisperFbank with the Python module `m`.
//
// An instance is pickled as the dict of its options and re-created from
// that dict when unpickled.
static void PybindWhisperFbankImpl(py::module &m) {
  using PyClass = WhisperFbank;

  py::class_<PyClass> cls(m, "WhisperFbank");

  cls.def(py::init<const WhisperFbankOptions &>(), py::arg("opts"));
  cls.def("dim", &PyClass::Dim);
  cls.def_property_readonly("options", &PyClass::GetOptions);

  // The GIL is released while the features are being computed
  cls.def("compute_features", &PyClass::ComputeFeatures, py::arg("wave"),
          py::arg("vtln_warp"), py::call_guard<py::gil_scoped_release>());

  cls.def(py::pickle(
      [](const PyClass &self) -> py::dict {
        return AsDict(self.GetOptions());
      },
      [](py::dict dict) -> std::unique_ptr<PyClass> {
        return std::make_unique<PyClass>(WhisperFbankOptionsFromDict(dict));
      }));
}
|
||||||
|
|
||||||
|
// Adds the Whisper filter bank bindings to the Python module `m`.
void PybindWhisperFbank(py::module &m) {
  // The options class is registered before the class that is constructed
  // from it.
  PybindWhisperFbankOptions(m);

  PybindWhisperFbankImpl(m);
}
|
||||||
|
|
||||||
|
} // namespace kaldifeat
|
16
kaldifeat/python/csrc/whisper-fbank.h
Normal file
16
kaldifeat/python/csrc/whisper-fbank.h
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
// kaldifeat/python/csrc/whisper-fbank.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
#ifndef KALDIFEAT_PYTHON_CSRC_WHISPER_FBANK_H_
|
||||||
|
#define KALDIFEAT_PYTHON_CSRC_WHISPER_FBANK_H_
|
||||||
|
|
||||||
|
#include "kaldifeat/python/csrc/kaldifeat.h"
|
||||||
|
|
||||||
|
namespace kaldifeat {
|
||||||
|
|
||||||
|
void PybindWhisperFbank(py::module &m);
|
||||||
|
|
||||||
|
} // namespace kaldifeat
|
||||||
|
|
||||||
|
#endif // KALDIFEAT_PYTHON_CSRC_WHISPER_FBANK_H_
|
@ -17,6 +17,7 @@ from _kaldifeat import (
|
|||||||
MfccOptions,
|
MfccOptions,
|
||||||
PlpOptions,
|
PlpOptions,
|
||||||
SpectrogramOptions,
|
SpectrogramOptions,
|
||||||
|
WhisperFbankOptions,
|
||||||
num_frames,
|
num_frames,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -26,6 +27,7 @@ from .offline_feature import OfflineFeature
|
|||||||
from .online_feature import OnlineFeature
|
from .online_feature import OnlineFeature
|
||||||
from .plp import OnlinePlp, Plp
|
from .plp import OnlinePlp, Plp
|
||||||
from .spectrogram import Spectrogram
|
from .spectrogram import Spectrogram
|
||||||
|
from .whisper_fbank import WhisperFbank
|
||||||
|
|
||||||
cmake_prefix_path = _Path(__file__).parent / "share" / "cmake"
|
cmake_prefix_path = _Path(__file__).parent / "share" / "cmake"
|
||||||
del _Path
|
del _Path
|
||||||
|
12
kaldifeat/python/kaldifeat/whisper_fbank.py
Normal file
12
kaldifeat/python/kaldifeat/whisper_fbank.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
|
||||||
|
import _kaldifeat
|
||||||
|
|
||||||
|
from .offline_feature import OfflineFeature
|
||||||
|
|
||||||
|
|
||||||
|
class WhisperFbank(OfflineFeature):
    """Offline extractor of Whisper filter bank features.

    Args:
      opts:
        The options. They are passed to the base class as well as to the
        underlying ``_kaldifeat.WhisperFbank``, which is stored in
        ``self.computer``.
    """

    def __init__(self, opts: _kaldifeat.WhisperFbankOptions):
        super().__init__(opts)

        self.computer = _kaldifeat.WhisperFbank(opts)
|
48
kaldifeat/python/tests/test_whisper_fbank.py
Normal file
48
kaldifeat/python/tests/test_whisper_fbank.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2023 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import torch
|
||||||
|
import kaldifeat
|
||||||
|
|
||||||
|
|
||||||
|
def get_ground_truth(x):
    """Compute reference Whisper log-mel features with torch and librosa.

    Args:
      x:
        A 1-D float tensor containing audio samples at 16 kHz.
    Returns:
      A 2-D tensor of shape [num_frames, 80].
    """
    n_fft = 400
    hop_length = 160

    # mel_filters: [80, 201]
    mel_filters = torch.from_numpy(
        librosa.filters.mel(sr=16000, n_fft=n_fft, n_mels=80)
    )

    # stft: [201, 301] for 3 seconds of audio
    stft = torch.stft(
        x,
        n_fft,
        hop_length,
        window=torch.hann_window(n_fft),
        return_complex=True,
    )

    # Drop the last frame and use the power spectrum: [201, 300]
    power = stft[..., :-1].abs() ** 2

    # mel_spec: [80, 300]
    mel_spec = mel_filters @ power

    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)

    return ((log_spec + 4.0) / 4.0).t()
|
||||||
|
|
||||||
|
|
||||||
|
def test_whisper_fbank():
    """Check the shape and sanity of features computed by WhisperFbank.

    The values are not compared with the reference element by element since
    the two outputs contain a different number of frames (see the shapes
    noted below).
    """
    x = torch.rand(16000 * 3)

    gt = get_ground_truth(x)
    print(gt.shape)  # [300, 80]
    assert gt.shape == (300, 80), gt.shape

    opts = kaldifeat.WhisperFbankOptions(device="cpu")
    whisper_fbank = kaldifeat.WhisperFbank(opts)
    y = whisper_fbank(x)  # [298, 80]
    print(y.shape)  # [298, 80]

    # Checks that hold independently of the exact number of frames
    assert y.dim() == 2, y.shape
    assert y.shape[1] == gt.shape[1], (y.shape, gt.shape)
    assert torch.isfinite(y).all()

    # print(gt[:5, :5])
    # print(y[:5, :5])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
torch.manual_seed(20231108)
|
||||||
|
test_whisper_fbank()
|
@ -1,6 +1,6 @@
|
|||||||
package:
|
package:
|
||||||
name: kaldifeat
|
name: kaldifeat
|
||||||
version: "1.25.1"
|
version: "1.25.2"
|
||||||
|
|
||||||
source:
|
source:
|
||||||
path: "{{ environ.get('KALDIFEAT_ROOT_DIR') }}"
|
path: "{{ environ.get('KALDIFEAT_ROOT_DIR') }}"
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
package:
|
package:
|
||||||
name: kaldifeat
|
name: kaldifeat
|
||||||
version: "1.25.1"
|
version: "1.25.2"
|
||||||
|
|
||||||
source:
|
source:
|
||||||
path: "{{ environ.get('KALDIFEAT_ROOT_DIR') }}"
|
path: "{{ environ.get('KALDIFEAT_ROOT_DIR') }}"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user