From 3196cff4414569d61acabad7a87b1735f68c98b6 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Sat, 2 Apr 2022 12:47:57 +0800
Subject: [PATCH] Add recycling vector.

---
Reviewer note (not part of the commit message): this patch was re-flowed
from a copy whose line breaks had been collapsed and whose <...> tokens had
been stripped. Restored: <deque>, static_cast<size_t>, static_cast<int32_t>
and std::deque<torch::Tensor>. The <utility> and <vector> lines in the
online-feature-itf.h hunk and the indentation of all context lines are
presumed; confirm them against the target tree before applying. The
author's e-mail address is missing from the From: header.
Also fixed: the test file's header comment named a .h file, the second test
was spelled "Testlimited", and the class comment referred to a MAX_ITEMS
setting that does not exist (the parameter is items_to_hold).

 kaldifeat/csrc/CMakeLists.txt         |  2 ++
 kaldifeat/csrc/online-feature-itf.h   |  1 +
 kaldifeat/csrc/online-feature-test.cc | 49 +++++++++++++++++++++++++++
 kaldifeat/csrc/online-feature.cc      | 43 +++++++++++++++++++++++
 kaldifeat/csrc/online-feature.h       | 49 +++++++++++++++++++++++++++
 5 files changed, 144 insertions(+)
 create mode 100644 kaldifeat/csrc/online-feature-test.cc
 create mode 100644 kaldifeat/csrc/online-feature.cc
 create mode 100644 kaldifeat/csrc/online-feature.h

diff --git a/kaldifeat/csrc/CMakeLists.txt b/kaldifeat/csrc/CMakeLists.txt
index 8dce57c..7e6f943 100644
--- a/kaldifeat/csrc/CMakeLists.txt
+++ b/kaldifeat/csrc/CMakeLists.txt
@@ -9,6 +9,7 @@ set(kaldifeat_srcs
   feature-window.cc
   matrix-functions.cc
   mel-computations.cc
+  online-feature.cc
 )
 
 add_library(kaldifeat_core SHARED ${kaldifeat_srcs})
@@ -40,6 +41,7 @@ if(kaldifeat_BUILD_TESTS)
   # please sort the source files alphabetically
   set(test_srcs
     feature-window-test.cc
+    online-feature-test.cc
   )
 
   foreach(source IN LISTS test_srcs)
diff --git a/kaldifeat/csrc/online-feature-itf.h b/kaldifeat/csrc/online-feature-itf.h
index 240612a..7599182 100644
--- a/kaldifeat/csrc/online-feature-itf.h
+++ b/kaldifeat/csrc/online-feature-itf.h
@@ -7,6 +7,7 @@
 #ifndef KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_
 #define KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_
 
+#include <utility>
 #include <vector>
 
 #include "torch/script.h"
diff --git a/kaldifeat/csrc/online-feature-test.cc b/kaldifeat/csrc/online-feature-test.cc
new file mode 100644
index 0000000..786c1c1
--- /dev/null
+++ b/kaldifeat/csrc/online-feature-test.cc
@@ -0,0 +1,49 @@
+// kaldifeat/csrc/online-feature-test.cc
+//
+// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+
+#include "kaldifeat/csrc/online-feature.h"
+
+#include "gtest/gtest.h"
+
+namespace kaldifeat {
+
+TEST(RecyclingVector, TestUnlimited) {
+  RecyclingVector v(-1);
+  constexpr int32_t N = 100;
+  for (int32_t i = 0; i != N; ++i) {
+    torch::Tensor t = torch::tensor({i, i + 1, i + 2});
+    v.PushBack(t);
+  }
+  ASSERT_EQ(v.Size(), N);
+
+  for (int32_t i = 0; i != N; ++i) {
+    torch::Tensor t = v.At(i);
+    torch::Tensor expected = torch::tensor({i, i + 1, i + 2});
+    EXPECT_TRUE(t.equal(expected));
+  }
+}
+
+TEST(RecyclingVector, TestLimited) {
+  constexpr int32_t K = 3;
+  constexpr int32_t N = 10;
+  RecyclingVector v(K);
+  for (int32_t i = 0; i != N; ++i) {
+    torch::Tensor t = torch::tensor({i, i + 1, i + 2});
+    v.PushBack(t);
+  }
+
+  ASSERT_EQ(v.Size(), N);
+
+  for (int32_t i = 0; i < N - K; ++i) {
+    ASSERT_DEATH(v.At(i), "");
+  }
+
+  for (int32_t i = N - K; i != N; ++i) {
+    torch::Tensor t = v.At(i);
+    torch::Tensor expected = torch::tensor({i, i + 1, i + 2});
+    EXPECT_TRUE(t.equal(expected));
+  }
+}
+
+}  // namespace kaldifeat
diff --git a/kaldifeat/csrc/online-feature.cc b/kaldifeat/csrc/online-feature.cc
new file mode 100644
index 0000000..0e2f250
--- /dev/null
+++ b/kaldifeat/csrc/online-feature.cc
@@ -0,0 +1,43 @@
+// kaldifeat/csrc/online-feature.cc
+//
+// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+
+// This file is copied/modified from kaldi/src/feat/online-feature.cc
+
+#include "kaldifeat/csrc/online-feature.h"
+
+#include "kaldifeat/csrc/log.h"
+
+namespace kaldifeat {
+
+RecyclingVector::RecyclingVector(int32_t items_to_hold)
+    : items_to_hold_(items_to_hold == 0 ? -1 : items_to_hold),
+      first_available_index_(0) {}
+
+torch::Tensor RecyclingVector::At(int32_t index) const {
+  if (index < first_available_index_) {
+    KALDIFEAT_ERR << "Attempted to retrieve feature vector that was "
+                     "already removed by the RecyclingVector (index = "
+                  << index << "; "
+                  << "first_available_index = " << first_available_index_
+                  << "; "
+                  << "size = " << Size() << ")";
+  }
+  // 'at' does size checking.
+  return items_.at(index - first_available_index_);
+}
+
+void RecyclingVector::PushBack(torch::Tensor item) {
+  // Note: -1 converts to the largest size_t, so nothing is ever recycled.
+  if (items_.size() == static_cast<size_t>(items_to_hold_)) {
+    items_.pop_front();
+    ++first_available_index_;
+  }
+  items_.push_back(item);
+}
+
+int32_t RecyclingVector::Size() const {
+  return first_available_index_ + static_cast<int32_t>(items_.size());
+}
+
+}  // namespace kaldifeat
diff --git a/kaldifeat/csrc/online-feature.h b/kaldifeat/csrc/online-feature.h
new file mode 100644
index 0000000..d64bb8f
--- /dev/null
+++ b/kaldifeat/csrc/online-feature.h
@@ -0,0 +1,49 @@
+// kaldifeat/csrc/online-feature.h
+//
+// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+
+// This file is copied/modified from kaldi/src/feat/online-feature.h
+
+#ifndef KALDIFEAT_CSRC_ONLINE_FEATURE_H_
+#define KALDIFEAT_CSRC_ONLINE_FEATURE_H_
+
+#include <deque>
+
+#include "kaldifeat/csrc/online-feature-itf.h"
+
+namespace kaldifeat {
+
+/// This class serves as a storage for feature vectors with an option to limit
+/// the memory usage by removing old elements. The indices of deleted frames
+/// are "remembered" so that regardless of the items_to_hold setting, the user
+/// always provides the indices as if no deletion was being performed.
+/// This is useful when processing very long recordings which would otherwise
+/// cause the memory to eventually blow up when the features are not being
+/// removed.
+class RecyclingVector {
+ public:
+  /// Keeps at most items_to_hold items; 0 or -1 (the default) keeps all.
+  explicit RecyclingVector(int32_t items_to_hold = -1);
+
+  ~RecyclingVector() = default;
+  RecyclingVector(const RecyclingVector &) = delete;
+  RecyclingVector &operator=(const RecyclingVector &) = delete;
+
+  torch::Tensor At(int32_t index) const;
+
+  void PushBack(torch::Tensor item);
+
+  /// This method returns the size as if no "recycling" had happened,
+  /// i.e. equivalent to the number of times the PushBack method has been
+  /// called.
+  int32_t Size() const;
+
+ private:
+  std::deque<torch::Tensor> items_;
+  int32_t items_to_hold_;
+  int32_t first_available_index_;
+};
+
+}  // namespace kaldifeat
+
+#endif  // KALDIFEAT_CSRC_ONLINE_FEATURE_H_