mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-10 02:22:16 +00:00
Add recycling vector.
This commit is contained in:
parent
2399cc8993
commit
3196cff441
@ -9,6 +9,7 @@ set(kaldifeat_srcs
|
||||
feature-window.cc
|
||||
matrix-functions.cc
|
||||
mel-computations.cc
|
||||
online-feature.cc
|
||||
)
|
||||
|
||||
add_library(kaldifeat_core SHARED ${kaldifeat_srcs})
|
||||
@ -40,6 +41,7 @@ if(kaldifeat_BUILD_TESTS)
|
||||
# please sort the source files alphabetically
|
||||
set(test_srcs
|
||||
feature-window-test.cc
|
||||
online-feature-test.cc
|
||||
)
|
||||
|
||||
foreach(source IN LISTS test_srcs)
|
||||
|
@ -7,6 +7,7 @@
|
||||
#ifndef KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_
|
||||
#define KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "torch/script.h"
|
||||
|
49
kaldifeat/csrc/online-feature-test.cc
Normal file
49
kaldifeat/csrc/online-feature-test.cc
Normal file
@ -0,0 +1,49 @@
|
||||
// kaldifeat/csrc/online-feature-test.h
|
||||
//
|
||||
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
#include "kaldifeat/csrc/online-feature.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
TEST(RecyclingVector, TestUnlimited) {
|
||||
RecyclingVector v(-1);
|
||||
constexpr int32_t N = 100;
|
||||
for (int32_t i = 0; i != N; ++i) {
|
||||
torch::Tensor t = torch::tensor({i, i + 1, i + 2});
|
||||
v.PushBack(t);
|
||||
}
|
||||
ASSERT_EQ(v.Size(), N);
|
||||
|
||||
for (int32_t i = 0; i != N; ++i) {
|
||||
torch::Tensor t = v.At(i);
|
||||
torch::Tensor expected = torch::tensor({i, i + 1, i + 2});
|
||||
EXPECT_TRUE(t.equal(expected));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(RecyclingVector, Testlimited) {
|
||||
constexpr int32_t K = 3;
|
||||
constexpr int32_t N = 10;
|
||||
RecyclingVector v(K);
|
||||
for (int32_t i = 0; i != N; ++i) {
|
||||
torch::Tensor t = torch::tensor({i, i + 1, i + 2});
|
||||
v.PushBack(t);
|
||||
}
|
||||
|
||||
ASSERT_EQ(v.Size(), N);
|
||||
|
||||
for (int32_t i = 0; i < N - K; ++i) {
|
||||
ASSERT_DEATH(v.At(i), "");
|
||||
}
|
||||
|
||||
for (int32_t i = N - K; i != N; ++i) {
|
||||
torch::Tensor t = v.At(i);
|
||||
torch::Tensor expected = torch::tensor({i, i + 1, i + 2});
|
||||
EXPECT_TRUE(t.equal(expected));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace kaldifeat
|
43
kaldifeat/csrc/online-feature.cc
Normal file
43
kaldifeat/csrc/online-feature.cc
Normal file
@ -0,0 +1,43 @@
|
||||
// kaldifeat/csrc/online-feature.cc
|
||||
//
|
||||
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
// This file is copied/modified from kaldi/src/feat/online-feature.cc
|
||||
|
||||
#include "kaldifeat/csrc/online-feature.h"
|
||||
|
||||
#include "kaldifeat/csrc/log.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
RecyclingVector::RecyclingVector(int32_t items_to_hold)
|
||||
: items_to_hold_(items_to_hold == 0 ? -1 : items_to_hold),
|
||||
first_available_index_(0) {}
|
||||
|
||||
torch::Tensor RecyclingVector::At(int32_t index) const {
|
||||
if (index < first_available_index_) {
|
||||
KALDIFEAT_ERR << "Attempted to retrieve feature vector that was "
|
||||
"already removed by the RecyclingVector (index = "
|
||||
<< index << "; "
|
||||
<< "first_available_index = " << first_available_index_
|
||||
<< "; "
|
||||
<< "size = " << Size() << ")";
|
||||
}
|
||||
// 'at' does size checking.
|
||||
return items_.at(index - first_available_index_);
|
||||
}
|
||||
|
||||
void RecyclingVector::PushBack(torch::Tensor item) {
|
||||
// Note: -1 is a larger number when treated as unsigned
|
||||
if (items_.size() == static_cast<size_t>(items_to_hold_)) {
|
||||
items_.pop_front();
|
||||
++first_available_index_;
|
||||
}
|
||||
items_.push_back(item);
|
||||
}
|
||||
|
||||
int32_t RecyclingVector::Size() const {
|
||||
return first_available_index_ + static_cast<int32_t>(items_.size());
|
||||
}
|
||||
|
||||
} // namespace kaldifeat
|
49
kaldifeat/csrc/online-feature.h
Normal file
49
kaldifeat/csrc/online-feature.h
Normal file
@ -0,0 +1,49 @@
|
||||
// kaldifeat/csrc/online-feature.h
|
||||
//
|
||||
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
// This file is copied/modified from kaldi/src/feat/online-feature.h
|
||||
|
||||
#ifndef KALDIFEAT_CSRC_ONLINE_FEATURE_H_
|
||||
#define KALDIFEAT_CSRC_ONLINE_FEATURE_H_
|
||||
|
||||
#include <deque>
|
||||
|
||||
#include "kaldifeat/csrc/online-feature-itf.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
/// This class serves as a storage for feature vectors with an option to limit
|
||||
/// the memory usage by removing old elements. The deleted frames indices are
|
||||
/// "remembered" so that regardless of the MAX_ITEMS setting, the user always
|
||||
/// provides the indices as if no deletion was being performed.
|
||||
/// This is useful when processing very long recordings which would otherwise
|
||||
/// cause the memory to eventually blow up when the features are not being
|
||||
/// removed.
|
||||
class RecyclingVector {
|
||||
public:
|
||||
/// By default it does not remove any elements.
|
||||
explicit RecyclingVector(int32_t items_to_hold = -1);
|
||||
|
||||
~RecyclingVector() = default;
|
||||
RecyclingVector(const RecyclingVector &) = delete;
|
||||
RecyclingVector &operator=(const RecyclingVector &) = delete;
|
||||
|
||||
torch::Tensor At(int32_t index) const;
|
||||
|
||||
void PushBack(torch::Tensor item);
|
||||
|
||||
/// This method returns the size as if no "recycling" had happened,
|
||||
/// i.e. equivalent to the number of times the PushBack method has been
|
||||
/// called.
|
||||
int32_t Size() const;
|
||||
|
||||
private:
|
||||
std::deque<torch::Tensor> items_;
|
||||
int32_t items_to_hold_;
|
||||
int32_t first_available_index_;
|
||||
};
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
||||
#endif // KALDIFEAT_CSRC_ONLINE_FEATURE_H_
|
Loading…
x
Reference in New Issue
Block a user