Add recycling vector.

This commit is contained in:
Fangjun Kuang 2022-04-02 12:47:57 +08:00
parent 2399cc8993
commit 3196cff441
5 changed files with 144 additions and 0 deletions

View File

@ -9,6 +9,7 @@ set(kaldifeat_srcs
feature-window.cc
matrix-functions.cc
mel-computations.cc
online-feature.cc
)
add_library(kaldifeat_core SHARED ${kaldifeat_srcs})
@ -40,6 +41,7 @@ if(kaldifeat_BUILD_TESTS)
# please sort the source files alphabetically
set(test_srcs
feature-window-test.cc
online-feature-test.cc
)
foreach(source IN LISTS test_srcs)

View File

@ -7,6 +7,7 @@
#ifndef KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_
#define KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_
#include <utility>
#include <vector>
#include "torch/script.h"

View File

@ -0,0 +1,49 @@
// kaldifeat/csrc/online-feature-test.cc
//
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
#include "kaldifeat/csrc/online-feature.h"
#include "gtest/gtest.h"
namespace kaldifeat {
// A RecyclingVector constructed with -1 never drops items, so every pushed
// tensor must still be retrievable through its original index.
TEST(RecyclingVector, TestUnlimited) {
  constexpr int32_t kNumItems = 100;
  RecyclingVector vec(-1);

  for (int32_t k = 0; k != kNumItems; ++k) {
    vec.PushBack(torch::tensor({k, k + 1, k + 2}));
  }

  ASSERT_EQ(vec.Size(), kNumItems);

  // Item k was built as {k, k + 1, k + 2}; check that it comes back unchanged.
  for (int32_t k = 0; k != kNumItems; ++k) {
    const torch::Tensor want = torch::tensor({k, k + 1, k + 2});
    const torch::Tensor got = vec.At(k);
    EXPECT_TRUE(got.equal(want));
  }
}
// A RecyclingVector limited to K items keeps only the K most recently pushed
// tensors, while Size() still counts every push.
TEST(RecyclingVector, Testlimited) {
  constexpr int32_t kCapacity = 3;
  constexpr int32_t kNumItems = 10;
  // Index of the oldest item that should still be held after all pushes.
  constexpr int32_t kFirstKept = kNumItems - kCapacity;

  RecyclingVector vec(kCapacity);
  for (int32_t k = 0; k != kNumItems; ++k) {
    vec.PushBack(torch::tensor({k, k + 1, k + 2}));
  }

  ASSERT_EQ(vec.Size(), kNumItems);

  // Indices below kFirstKept have been recycled; the test expects At() to
  // terminate the process for each of them.
  for (int32_t k = 0; k < kFirstKept; ++k) {
    ASSERT_DEATH(vec.At(k), "");
  }

  // The last kCapacity items must still be available and unchanged.
  for (int32_t k = kFirstKept; k != kNumItems; ++k) {
    const torch::Tensor want = torch::tensor({k, k + 1, k + 2});
    const torch::Tensor got = vec.At(k);
    EXPECT_TRUE(got.equal(want));
  }
}
} // namespace kaldifeat

View File

@ -0,0 +1,43 @@
// kaldifeat/csrc/online-feature.cc
//
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/online-feature.cc
#include "kaldifeat/csrc/online-feature.h"

#include <utility>

#include "kaldifeat/csrc/log.h"
namespace kaldifeat {
// Creates an empty vector that holds at most `items_to_hold` items.
//
// A request for 0 items is stored as -1. PushBack() compares the limit as an
// unsigned value, so -1 means "never remove anything".
RecyclingVector::RecyclingVector(int32_t items_to_hold)
    : items_to_hold_(items_to_hold != 0 ? items_to_hold : -1),
      first_available_index_(0) {}
// Returns the item that was pushed as the `index`-th element (0-based,
// counted over all PushBack() calls, including items already recycled).
//
// An index that refers to an already-recycled item is reported through
// KALDIFEAT_ERR. An index past the end is caught by std::deque::at.
torch::Tensor RecyclingVector::At(int32_t index) const {
  const bool already_recycled = index < first_available_index_;
  if (already_recycled) {
    KALDIFEAT_ERR << "Attempted to retrieve feature vector that was "
                     "already removed by the RecyclingVector (index = "
                  << index << "; "
                  << "first_available_index = " << first_available_index_
                  << "; "
                  << "size = " << Size() << ")";
  }

  // Translate the global index into a position inside the deque; 'at'
  // performs the bounds check.
  const int32_t offset = index - first_available_index_;
  return items_.at(offset);
}
// Appends `item` as the newest element.
//
// If the number of items currently held equals the configured limit, the
// oldest item is dropped first and first_available_index_ is advanced, so
// indices seen by callers keep counting as if nothing had been removed.
void RecyclingVector::PushBack(torch::Tensor item) {
  // Note: -1 is a larger number when treated as unsigned, so a negative limit
  // never matches the current size and nothing is ever removed.
  if (items_.size() == static_cast<size_t>(items_to_hold_)) {
    items_.pop_front();
    ++first_available_index_;
  }

  // `item` is a by-value sink parameter: move it into the deque instead of
  // making a second copy of the tensor handle.
  items_.push_back(std::move(item));
}
// Returns the number of recycled items plus the number of items still held,
// i.e. the size the vector would have if nothing had been removed.
int32_t RecyclingVector::Size() const {
  const auto num_held = static_cast<int32_t>(items_.size());
  return first_available_index_ + num_held;
}
} // namespace kaldifeat

View File

@ -0,0 +1,49 @@
// kaldifeat/csrc/online-feature.h
//
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/online-feature.h
#ifndef KALDIFEAT_CSRC_ONLINE_FEATURE_H_
#define KALDIFEAT_CSRC_ONLINE_FEATURE_H_
#include <deque>
#include "kaldifeat/csrc/online-feature-itf.h"
namespace kaldifeat {
/// This class serves as a storage for feature vectors with an option to limit
/// the memory usage by removing old elements. The indices of the deleted
/// frames are "remembered" so that regardless of the `items_to_hold` setting,
/// the user always provides the indices as if no deletion was being
/// performed.
/// This is useful when processing very long recordings which would otherwise
/// cause the memory to eventually blow up when the features are not being
/// removed.
class RecyclingVector {
public:
/// By default it does not remove any elements.
///
/// @param items_to_hold  Maximum number of items kept at any time. A value
///                       of 0 is stored as -1, and a negative value
///                       effectively disables removal because the limit is
///                       compared as an unsigned number in PushBack().
explicit RecyclingVector(int32_t items_to_hold = -1);
~RecyclingVector() = default;
/// Copying is disabled.
RecyclingVector(const RecyclingVector &) = delete;
RecyclingVector &operator=(const RecyclingVector &) = delete;
/// Returns the item at position `index`, where `index` counts every item
/// ever pushed (including ones that have since been removed).
///
/// Requesting an index that has already been removed is reported as an
/// error via KALDIFEAT_ERR; the remaining range check is done by
/// std::deque::at.
torch::Tensor At(int32_t index) const;
/// Appends `item`. If the number of held items equals the limit, the oldest
/// held item is removed first.
void PushBack(torch::Tensor item);
/// This method returns the size as if no "recycling" had happened,
/// i.e. equivalent to the number of times the PushBack method has been
/// called.
int32_t Size() const;
private:
/// Items currently held, oldest at the front.
std::deque<torch::Tensor> items_;
/// Maximum number of items to hold; -1 (or any negative value) means the
/// limit is never reached.
int32_t items_to_hold_;
/// Index (in caller terms) of the item at the front of `items_`; equals the
/// number of items removed so far.
int32_t first_available_index_;
};
} // namespace kaldifeat
#endif // KALDIFEAT_CSRC_ONLINE_FEATURE_H_