From e043afe3d69b4e219bfcef7a78eda3c6b67d7a80 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 16 Jul 2021 17:35:23 +0800 Subject: [PATCH] Add kaldifeat.Fbank --- .flake8 | 4 + .gitignore | 1 + kaldifeat/csrc/CMakeLists.txt | 6 ++ kaldifeat/csrc/feature-fbank.h | 14 ++++ kaldifeat/python/csrc/feature-fbank.cc | 18 +---- kaldifeat/python/csrc/kaldifeat.cc | 2 +- kaldifeat/python/kaldifeat/__init__.py | 3 + kaldifeat/python/kaldifeat/fbank.py | 82 +++++++++++++++++++++ kaldifeat/python/tests/test_data/run.sh | 6 ++ kaldifeat/python/tests/test_data/test2.wav | Bin 0 -> 16044 bytes kaldifeat/python/tests/test_fbank.py | 72 ++++++++++++++++++ kaldifeat/python/tests/test_kaldifeat.py | 24 +++--- kaldifeat/python/tests/test_options.py | 5 +- 13 files changed, 207 insertions(+), 30 deletions(-) create mode 100644 kaldifeat/python/kaldifeat/fbank.py create mode 100644 kaldifeat/python/tests/test_data/test2.wav create mode 100755 kaldifeat/python/tests/test_fbank.py diff --git a/.flake8 b/.flake8 index 0e88669..3551e08 100644 --- a/.flake8 +++ b/.flake8 @@ -1,5 +1,9 @@ [flake8] max-line-length = 80 +exclude = + .git, + kaldifeat/python/kaldifeat/__init__.py + ignore = E402 diff --git a/.gitignore b/.gitignore index f9f78d8..c697d52 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ build/ build*/ *.egg-info*/ dist/ +__pycache__/ diff --git a/kaldifeat/csrc/CMakeLists.txt b/kaldifeat/csrc/CMakeLists.txt index eddab10..d94d43a 100644 --- a/kaldifeat/csrc/CMakeLists.txt +++ b/kaldifeat/csrc/CMakeLists.txt @@ -9,6 +9,12 @@ set(kaldifeat_srcs add_library(kaldifeat_core SHARED ${kaldifeat_srcs}) target_link_libraries(kaldifeat_core PUBLIC ${TORCH_LIBRARIES}) +# PYTHON_INCLUDE_DIRS is set by pybind11 +target_include_directories(kaldifeat_core PUBLIC ${PYTHON_INCLUDE_DIRS}) + +# PYTHON_LIBRARY is set by pybind11 +target_link_libraries(kaldifeat_core PUBLIC ${PYTHON_LIBRARY}) + add_executable(test_kaldifeat test_kaldifeat.cc) target_link_libraries(test_kaldifeat PRIVATE kaldifeat_core) diff --git a/kaldifeat/csrc/feature-fbank.h b/kaldifeat/csrc/feature-fbank.h index fc40359..80a3ba9 100644 --- a/kaldifeat/csrc/feature-fbank.h +++ b/kaldifeat/csrc/feature-fbank.h @@ -8,12 +8,16 @@ #define KALDIFEAT_CSRC_FEATURE_FBANK_H_ #include +#include #include "kaldifeat/csrc/feature-common.h" #include "kaldifeat/csrc/feature-window.h" #include "kaldifeat/csrc/mel-computations.h" +#include "pybind11/pybind11.h" #include "torch/torch.h" +namespace py = pybind11; + namespace kaldifeat { struct FbankOptions { @@ -42,6 +46,16 @@ struct FbankOptions { FbankOptions() : device("cpu") { mel_opts.num_bins = 23; } + // Get/Set methods are for implementing properties in Python + py::object GetDevice() const { + py::object ans = py::module_::import("torch").attr("device"); + return ans(device.str()); + } + void SetDevice(py::object obj) { + std::string s = static_cast(obj); + device = torch::Device(s); + } + std::string ToString() const { std::ostringstream os; os << "frame_opts: \n"; diff --git a/kaldifeat/python/csrc/feature-fbank.cc b/kaldifeat/python/csrc/feature-fbank.cc index 5122605..5f26a6b 100644 --- a/kaldifeat/python/csrc/feature-fbank.cc +++ b/kaldifeat/python/csrc/feature-fbank.cc @@ -19,22 +19,8 @@ void PybindFbankOptions(py::module &m) { .def_readwrite("htk_compat", &FbankOptions::htk_compat) .def_readwrite("use_log_fbank", &FbankOptions::use_log_fbank) .def_readwrite("use_power", &FbankOptions::use_power) - .def("set_device", - [](FbankOptions *fbank_opts, py::object device) { - std::string device_type = - static_cast(device.attr("type")); - KALDIFEAT_ASSERT(device_type == "cpu" || device_type == "cuda") - << "Unsupported device type: " << device_type; - - auto index_attr = static_cast(device.attr("index")); - int32_t device_index = 0; - if (!index_attr.is_none()) - device_index = static_cast(index_attr); - if (device_type == "cpu") - fbank_opts->device = torch::Device("cpu"); - else - fbank_opts->device = torch::Device(torch::kCUDA, device_index); - }) + .def_property("device", &FbankOptions::GetDevice, + &FbankOptions::SetDevice) .def("__str__", [](const FbankOptions &self) -> std::string { return self.ToString(); }); diff --git a/kaldifeat/python/csrc/kaldifeat.cc b/kaldifeat/python/csrc/kaldifeat.cc index 3398f8d..ca4bd79 100644 --- a/kaldifeat/python/csrc/kaldifeat.cc +++ b/kaldifeat/python/csrc/kaldifeat.cc @@ -27,7 +27,7 @@ PYBIND11_MODULE(_kaldifeat, m) { PybindMelBanksOptions(m); PybindFbankOptions(m); - m.def("compute", &Compute, py::arg("wave"), py::arg("fbank")); + m.def("compute_fbank_feats", &Compute, py::arg("wave"), py::arg("fbank")); // It verifies that the reimplementation produces the same output // as kaldi using default parameters with dither disabled. diff --git a/kaldifeat/python/kaldifeat/__init__.py b/kaldifeat/python/kaldifeat/__init__.py index e69de29..e177288 100644 --- a/kaldifeat/python/kaldifeat/__init__.py +++ b/kaldifeat/python/kaldifeat/__init__.py @@ -0,0 +1,3 @@ +from _kaldifeat import FbankOptions, FrameExtractionOptions, MelBanksOptions + +from .fbank import Fbank diff --git a/kaldifeat/python/kaldifeat/fbank.py b/kaldifeat/python/kaldifeat/fbank.py new file mode 100644 index 0000000..5196956 --- /dev/null +++ b/kaldifeat/python/kaldifeat/fbank.py @@ -0,0 +1,82 @@ +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +from typing import List, Union + +import _kaldifeat +import torch +import torch.nn as nn + + +class Fbank(nn.Module): + def __init__(self, opts: _kaldifeat.FbankOptions): + super().__init__() + + self.opts = opts + self.computer = _kaldifeat.Fbank(opts) + + def forward( + self, waves: Union[torch.Tensor, List[torch.Tensor]] + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Compute the fbank features of a single waveform or + a list of waveforms. + + Args: + waves: + A single 1-D tensor or a list of 1-D tensors. Each tensor contains + audio samples of a soundfile. To get a result compatible with Kaldi, + you should scale the samples to [-32768, 32767] before calling this + function. Note: You are not required to scale them if you don't care + about the compatibility with Kaldi. + Returns: + Return a list of 2-D tensors containing the fbank features if the + input is a list of 1-D tensors. The returned list has as many elements + as the input list. + Return a single 2-D tensor if the input is a single tensor. + """ + if isinstance(waves, list): + is_list = True + else: + waves = [waves] + is_list = False + + num_frames_per_wave = [ + _kaldifeat.num_frames(w.numel(), self.opts.frame_opts) + for w in waves + ] + + strided = [self.convert_samples_to_frames(w) for w in waves] + strided = torch.cat(strided, dim=0) + + features = self.compute(strided) + + if is_list: + return list(features.split(num_frames_per_wave)) + else: + return features + + def compute(self, x: torch.Tensor) -> torch.Tensor: + """Compute fbank features given a 2-D tensor containing + frames data. Each row is a frame of size frame_lens, specified + in the fbank options. + Args: + x: + A 2-D tensor. + Returns: + Return a 2-D tensor with as many rows as the input tensor. Its + number of columns is the number mel bins. + """ + features = _kaldifeat.compute_fbank_feats(x, self.computer) + return features + + def convert_samples_to_frames(self, wave: torch.Tensor) -> torch.Tensor: + """Convert a 1-D tensor containing audio samples to a 2-D + tensor where each row is a frame of samples of size frame length + specified in the fbank options. + + Args: + waves: + A 1-D tensor. + Returns: + Return a 2-D tensor. + """ + return _kaldifeat.get_strided(wave, self.opts.frame_opts) diff --git a/kaldifeat/python/tests/test_data/run.sh b/kaldifeat/python/tests/test_data/run.sh index 53b3f39..361b657 100755 --- a/kaldifeat/python/tests/test_data/run.sh +++ b/kaldifeat/python/tests/test_data/run.sh @@ -11,6 +11,12 @@ if [ ! -f test.wav ]; then sox -n -r 16000 -b 16 test.wav synth 1.2 sine 300-3300 fi +if [ ! -f test2.wav ]; then + # generate a wav of 0.5 seconds, containing a sine-wave + # swept from 300 Hz to 3300 Hz + sox -n -r 16000 -b 16 test2.wav synth 0.5 sine 300-3300 +fi + echo "1 test.wav" > test.scp # We disable dither for testing diff --git a/kaldifeat/python/tests/test_data/test2.wav b/kaldifeat/python/tests/test_data/test2.wav new file mode 100644 index 0000000000000000000000000000000000000000..0016f503339831f22535770fb401321ced422b41 GIT binary patch literal 16044 zcmWk!Wq8|26Am*oGcz+Y+KyQ&Gc#TpuZ-8;m06M@hS*MGhnbm~85-s$`SLybC5=`y zv-9rE>Yc?#M@0Mz5E3E<6T=UlIB%pRBqSufb=}%}CUjj$NJL0UX#bw`d;WhfC45T8 zUft19&Dz{0%Lf}`8(oWffEy?J0M_6XKqkK>JV`hlXA&_IknK@#?{8+L^Gk7D%w+Y@ z)YL$2TSzUj#6L$$(8Z}>Oj2WzS!e`uL~^LF86P=?{IzV;5`)^3HqpT!Q}j)lh@& zRz1$J&x??2(Fy1k+-}lefFB3}$)t1m3AA$zKJiwP-`4qbWHt`jD%w zpIBH!%ni&0UY=YFxPs0=6;TBzi>in`9i-w_@6c=(sS~THELOPu&m_Iivt>^euFxq{ zH^q_tfbLF-hKr$8ScO`fIF#hgPtNu!fop5q-wb8X*sR}|VyHef_-(~@fqZG9Utn~Z9z8qx$A!V`!) zeVQenvNN-_5Lf-P_0xd-bjjKSNg35ZgL$h}7i&KtY;UYh!heKUWF=4XrA|xkfI2LA9 znYV@vv`yu#Q6&kgNe1_=Vb=ja&2XK#$^>oVAlU}&0))t)@LA|9(Z7Sez3(_KnB{3} zD%gp*FCH2D(`DA6R5qD2#<$`WCw5cr!}-t`SedHEc*D6Y_>ntaX4~l6Q#D??G%VVn zctQ7zg_d)QPi!bVCJ%FvaF47E3V}qj1_6!HjCmYV>b>rmV7{VluAnILdtqhtNZ0xL z*iw@mBmRjb>%`+*5w3ui;6kbmQ<=9reLf$ljBcUzzn-dElb4cJOEN65QE@*S@HOID zyb?}{ba`v;0ic%1!b--^g}?TfcY9$SYj8m&LSktpe{!JrpC-wQvw0$Ey&PAD3iThj z0HVOY)bor$?)5aq{PK!h%|iXMQ>WJ2Bt_L~4W(_*yY~bbN6Mk@;0}^tz!Ovj0*U1b zrLp+1x4sTP4of2C3)R3G^C76kdQJM4qn4PD?4y(k*#d z*VfrO8of9yLR7q@t7bXt-0AxxOe0PedzI+5W%W%!g4BUki?0Z`^pAHvYgMlIN=Za) z=ThHTP1pYV=8~P+cBx0%K6C-H42i%^$O$@+{U&uV`);XMLv6R>M8@)Uad8!Ig9hub zZhHd8B4kkCa4$)DfE`#y{y`|f#Ka_ohY(r^CIrZoF?n&-<0a`gq(NhXQz6n%IE`VTRoUtnKO70DSX zoog8Hc{KUYDqZrX+K3TkchS=*_-wQ<1|+byke35oAYpN*;zz?r{64$*TT1B$D3}S0 z&KnOGw7FDM3U;Tz-qe5UCQaO5DV8Wx-EV}li}BP9 zz8Ad%bCXa(b^?C^cBE=7Ev`N6neV()w0V_Qq3qQSkD1_sq2?nM@_90Xw@Jf0vfc`H^^#Vh)c%>Btj$WKtD>KKFh_Sj+DLkJ-je zK>n7Fx@DQmPXAXC&M0}jI9UXo2f9dQxaxTM2(I6Pvx)^;dqEbrkv@I7zoW^iT#%EJ zx`W+DLnAfN5m{u)(BUzz77$m+@~xr}87s z)Paw)eL^CNxq6k>b?*0rzC;gW9Ecde8yp}L2?%;A@^Zi-R~gGBZKmw*b)%{1Uhjs` zlH|;KUMpjs5(N)HmBch*{Z&GMd`AEyuW%{L$=)tLg` z4@L&X8vX-$QYsQZbFXFC7ayu`?9rR*Tz@Ust8>dr-rX$-iMC9*PJFPnwvUh_@qO_J z!)1N79ea$^Rb|CJ7Jm+>wANLib8n==taU1VYgO+?_S4gnJkq`uFluDFNE2>rnliA~ z3yb5fBLU2)>lhv4?k$|WCf~suqdtUN`Kmf{jCZTv6}!7oF;vu2UyjRpo6^fXMLh_2 zLG1`O(U_~9@u%oYy?<}l)bhrL{6BhoZ6rL+LXO9N!rmgy1N}e(DH=N;b0^r#=EEnG=+;Q>`t>{AK-KI3ai39D#w;q0;R3A4qCi)4GVOL z7E5?a`~(<)^5l!S%W+YmR$hl~59vQq2;N+ovgq}!e_8Y<{SBv{ejllTl3`z37#rdn z=O3yz=XQ+viz%RMFNe2fg%k2C;G0$7qQ_I*rqkd6C<6~PrdjE)rZa=u3mW3(YR#q<&Yl7MsAUX~m;(rbXGm7q*cjcQAy*{}am_$U-=%BA zw_Ebd*jX%|5<`X31wDl2XlS;9z_q}&me(CJwPSOiVyD3Y`)fW~;bka*;Jg)mfSicK z$2|)9?D5-bQd>i2bY)`nN895Hog6`OCgV6|2HFQ_Q^#36eqH`(&9m;%smGfFMY7?V zy@~JMi0$ZN!V$m%oFffjjbcZFUbtPiFxQlpWH0_0I^2A_bbsa*u3F*+qySPy2=v1o zl?pIP= z_3|uIe=)l$a&QQIpPIt@#ox$3TzjbJ*0iqhOXVZRmQGRr+oCoxBgB579^ez{nC{5$ zehH3yjEGA7O`R!)ZsVG!ykDumnL<=mI1qk7tzdQV_Z0}$dG_9)!HA5gG@HD0Mgm?% z&nMg`se$)_PJ%zWBV69c%dSQ5oZRQN@v()rh>AnmrM#-dETj|?Mt0E!oVtwIQbF_E zVd8SEbc6O)YnJCsXhZxTyw%oDQ;AzxdGQhT6?2Z;<6!x?X52TjKlqVcjQbr+4BqQ@-#k`5T1;Y8bmhFka}~pgthEfRl*UJZ7oy(H>$QPr)l})oCg6PpnIf zc;G!}dz077aYCX~#IBL5m|UTh3yeQVCG-crMR(xlX4aJTw;meHUNe^$GAOif^i7KF z!zhu2z&Ahzp&Hc@hV??M4{N(hB`#hYjA`T*MW)?k3#fju6?}(k&F&F|7Y#I658hlf zmfEKM%X+)lOxOf!l<*Un0~(2kFar@ZA4fYOy>uCwl@G%@%{C=_(*2T5X**ygIE>oF zdLq~?6f{^2LW^gm8nk_EguUN}Uq@qz6+k6`Aznn^3vcsUus*KcEh)PAU_h_oXrZX! z5NnyT0zHC%(M&mljE&M0t^3FP|ClKpH@xM@^S>FrA1h1t1nbBnIBM*npf@fVCO4Iz zZGN6S*zvxiEh~llojwgWLOUs^nANFz1+8_*`t26bk_y^2)K&oSY{cZ?2FUi{OOk1VeWZcUT^m{Lza&1+Mf670Y|nL0 zMkneb4ba~dEmn}=NztJuy^*2S8hK@-DJN>+`&duBF`x(PlR_}<;bbqI)oqPLQQ9=2 z>w4u-mOFQzb_Z66IaHUV;EX+GzuW95Lxen3vGKa zw?f`p{zy%epvX_K$lOaQhe*hDB9LO1pIvuy;OJ6@%tiff2djW@F&x|+*&OsF#bQ|D zU7lAgL)7I&GJd;v;L5u)?j}`I)#0bG1HF(to~>5>ch6sQa}vTjJ+^$`+Nj6a(_|m8 zmXwen7J0y1)w)UJwy4mwU#E7(fy}z3-&9@r5$r-A;JW6#sP^jJJufNcs;h55?AII} zk82=XgI*+aOlY{eXSl^TRcWD<6FIH@C6u&m7MoHB#UgTy$`qXfOvA4suhms~QRC|_ zPlK<w^c}7W` zr)6Y&$#*oW3QHq9f-xjL%zD^g9--!*DnT1N#vV3vi`aY*X5|(VeJCod1;MtGN3HaU zT%jbje=VPS@gjC6tdsn~4bqMTTEq#jZI&oCZy|#Tq1NW&8v-6vlu`g`P`a7Lg3ywG zTl*%=g?-iitpdDvL=GmzkiEhGNcS-0aLB{bJYV_Tdhn=66RL1usxu=1>45a9ZER-7 zX1PdL0`NP4fi2Vr$EOFiI-N2&Dbu~s&`Yk~n_bLtp=rTM@G_l} z9GEZHU^Dz_%~7etlwn-e9G83>^+ zC7sGTS*6^2V8L7FsDZe%bnusWP2vV%0NlhKi$3P-Z&RY7EfPEVS1Z4Gjjzajiqt~$ zlxp^TMqUN6n?IK-Wv(yjL=19@?;|V#w!m}T(dcO3-`3F@M};p=JZmW{GUuOQ_#lJO zUFu-c>#W|Yf!+^`1+p$ieJs2UicODv=Vfp4NV#+E_|Szqqug^`ADSpE zhr}@U^4Z1cHmj+1u_HRM4mSfu;$?{PU=?{ARxR?o=djtjqS5NRfy9~z*)E(5)K=&% z@|Jlb?L^tv&Y{^MsVIX1=NBO;^tLTD<&qRJ??N?Q9~r92h|a@Z;pKAa3(O3p9s;TQ zoVVFaHLio|Yc|T(7K2`1Q64yVpb?NJ=)`{WyKIxDPTBl#w5Ops?;!6hO&JzO4lp|S zoRaDG<1;yucl2*KUko{bzP_d9A_^B*?`W{;&`)pmm0Za z5yFV(V>=|EG0vH24DKT(^tOcdLeQU>j!QaLKQaJ~JghmgMJExWYFT;La!hIQ#Ke3f%CJ>oeS`2cqZH~~cBE=QJn z>@fLHR%G6-H}Zn{U;6RN>b}rbF_lYJrheaIy9p+s0eKp27n19^uDw^}-B?_M zN^TfuW-HIEP)HoV+_w!fEN9hA$SaqtO}kZ6vPz!e6XM^epoW@87aFWP-T3C zpTMJu0Rnc}O82{^0>xB|NuSA>w*+0#jy#0!4c2p<(wY(K9=TA5&+cQ}Q5qlx+7sTB zf`Jyp=?hZoMn~LxB0O<-0W|Od+Z8V9w#VRxgu!pw=EXdKyNg-~)ltx#=efrkFHOvf z&FR&Y7c%G_)PwY zxf~MVctNXpQ*$`J#w=5l<%c-KvWYwSo@E|ANy|T#9jrh2qvQ2RzTh$9LhLW!REsi& zl!YUmMa9=rhH2x_Jkr7*&lYZYHqI6U^wV4oB9w5i0X=}1FcI45G^{-%G%%c6;w11QP zit;*cFGwneS|0Qh#LW_Y!Fb}O*w;SGX8YyJX9L?Z3jT0+Q}dyVRC}IcVS3x;IXn4O z^8nwC*v~|N@ENf+cEu;o{2#fRS^c)_1p=-!wE)VY*6`Gd^g2uzgcKiJ_WJq9$C8{u z4#6VkwwJSsjI`<0r=}aZPm;7KQ&2dalB!WA*86hxqS|LWy{oC{>r;`mK=J4%A$F%iIw;`s`ADt-V*OJIRHs;0k~{=a1Id@&a&Fg>vetC%XCvO(~6<1RR83);>GTTRlLR<$Maz;*eW2J z`~-E`pJL%5_i_4ElUYt9OAh$~|6<~^JR72ZFG_zg>+ro9-$_~m-r?}!_ndSzudPM( zsFc_wZ=+T~4KyLXe1+B^PsmJf&<&2dP4EFhqJNB-=R*T+(e1;_l{*E`>D^EdC6*^s zjPA}~{izY*gb8QiL_mnt5ck^WrwKt4H6B)*m)@JG4Qo>9oSK574!7lawLK0ip$phf zK!p4({-*CIQ%@=3iHus2jKxG9Sf5hMu`Ar(DY6o*zTL4Y>>*AVd`r3$*Xn)FI7NK_ zD5_dMEq_bzI;EOtS7O%dyM9O8$+apnbjv*T$}CboZTX>17Z&JO zp+*M8`i!1=pNadn<*y4)By7~)RQ1r}k2W}OoE?BbuoFun=VY_w+=!xzZBOuG0t*Nr z7sR7|b4@-;6py-B(*!%{)zBL%CWTY>ZQ!8rxIXBq6{|r~2Xk>bVgEV&tNQoisW$(- zYivoR7gl8d%FS<`Suj%da=0Hx$EktJq~9^)9wz#sLSX-y(x_x1>KL>uF*!Y^u3*wg z*2~g0@Ggc3aL80tgx?j@0f}d$|5f$#3u*b#FdjQ?O6QS9@l_LBzbaG{*&JY@|MI_Oc0&>x4XncPF|;CRt2(6dYraoFG6yW{0^Jh6 z0tV#fI7#mwgD~Nn{R1T@c>A|Fu1$t?_U>lP{9Tpj4kyB&;Jv_7LT2Pw=P~v4rSolv zau2cYz?sMjXQAk1@1adlf5tN)Zj!_RLKD6R=2%>mRvlNU{>YD|l|morI_Y=o9H!Fc zzuM4447X-WKs*u6aTU{gzEa!npI5|s4JRXN++D??eeZ<|40e0(jlW1v0P4^x{vM`5 z;*=rb^1I0?6g^mi`68Rra&XZ>{j~F9q$bf0WaAEny4a4$FZ>p+JD;XSFMxFDx6{n) zLZ{FQ|Jr7UPT*X?DndY{pVL*}<+Hl~qEt&_Rxo1%dP}WK z9jFqWcp&@C`bo%XoDJAd(24xu7^*^ZZkGc-kGRCuI+oqSlYk9bP#I|hpsFX2p{can3i~Q&l6s;#IAw;B{ zBUFMOFcgb3l?xC}C<8)Cg6K|{=jyi?@|yxO9TQ8TPFhXcfAy-fF)CV4wUNh&(x4-@ zHIQsBFR?V}T6&gyA6bOEliG@W`x8WaOcwq15?TN$(q_~%XSQnC++Pjx>5}vkC^Qku zWH*a1ozTp33yT#e1Arp>qK}_}?1o0y^E`Q$6|6wrml{?5ZIY;VPA3jAzZE%IOQ+kys`H*7|A;z!g@O~Wc|=cedX(r(}dD$h$)_rgkb%g@Y` zt@vF^oXh&&_GQgf|DF#I%>raehEd{yS@75{nGq%&F>zJ_z-|XQj4ZKS}X3KxLXm;BOvb)_2i;TvERc=j?Q*5 z`s6Y4sknP?scL02Yc*b}2PnDl8;*8a>1df;sr^J`5a}#niIVsHqWSN9L_MGXhjI~q zlLVJIjck|Y*@Z-kZsGF`YL91|CT-ri{yD#k@)%x9I#?P$`c7`d-ZSbvDH_-x|Hl2c zy4GxUO=XH6MH0#5;wly=?kXK~PKh~BzCo^xx$1mQsd>V+;wyI<_M-%+vTBdcscOo4 zoIo?|UTJlbSNNo(pK4I?n10qX}I#BB?VG#*}WX`^RK(Z?V= z7NM|zpj67X2;qKMWY$wEojGG?bTx)B~aHaXel%$7cUnhiByxbnf{|m$sUBY%- z{1!IvBxbeI6yYXzLMeK@RcXJgcKj{i4Do2#eTxaB2w=8j^Gcg#MmqH}yvk9iw4D~$nDmsvXo4~LpMlwiKC7vXRK7oD4-%ddQFnGL zeXRN|V7Bo_I3beG+rL|dG7De>Ib|EF|Ob^*57ZN z&a*k$YKVIkR-kb+OFLu42CSKp24qe0e~~onG0{VvdosJJZ(%E*QjPqAv>w~P7^eXq zLua`UDE}HeTs*)mgsK?V3dtih#c9{Qs0H98z9}G3UwJXF<`VAOj?E>C6WssjNDhQ_vpm2{+r*ML}pQ(&p?{aK}o z=j=+zh*?s+Z*oxMn(uSm3a~SNmy@=1clU)%d+J`~2EVW6yYRF%KH3zRBb*8{(f@Cb zS&_3fDu5YUg8gl(#r2EC=L5^Jwf4p0mpZD`G${utjP#05K+?mpCvF;eh!gRpslOke zDI7~&hf+9%n#0SACLv*8Nu8wMVZ$cgE5fy^TuHbvv9T~`TvOwgZ!fMM5Q{r!-zsL& z{zmW`v7`!R{oPwHd&Kn{+81oZTz3ytsOkTaT}rh_PVjlHFGTO#y^1Xcrg1pmqw1`& z)B@SW0GP&8ZaBUkVHp}F0~BxhOVGJ9wZEj9F$5jtJgm8~(rPLfag>}++#YP8?>BqC ztcO_%ZA%)f%3j)UVi9glmLRPM*Xoze7L~1TNe(AjR68v%noz>Gllerg;1hbbGq_Si zW}tv*%GoUW!3!z|zRh@L zfD-|l1T3_dZ88*K&E$mUdC?n=6DWW1CW`LRC4y->$=f<7!i=jxEnPJWjJgVNaDdkZ zh3X#H^xep9T6KQm1V`t4@F&tR5ee+lY#n`*`-G}Qd6ao#;HnDb8%IbXFNE0`I?gB; zchd7SCo6YYJlX zUN;AepLf22(FeJ)s@7-LM5~#sJ}8nW-`p$~L03+H92iUo!%+@JEW-B_m2eFOUh;6|cM?#vmKtkFAR$249~o#H(oXTme` zr%)}uhKbs|U({HNP4?#SNlkpv1JXal6#raR%)q^jlgI+iwCK`Ip3!8)89*EN!`)Td zv5lL&0-a)AskU5yX`37G4_=60vN`ytwn~O24(Y_%QVKb{$h%Jc zmvSTL`q&E{$VIqGp1LS%^y?S!mY22$Gx*At}-jq7)Rhvetox! z(g$BjVmiq=$XtUlSdoE6?$8}e)E89DQ85}|5t`(1a+6l`izNzwNuhQ~$=~x1CpeS$ z2Vd1>4uxgzLvGPemZU8#naju8f^{ggUB(~1DjVibSWb}BLsyRT?;%zZRsDmMV|r5g z(eNXtOtsBMzI_H-9()yTZyGv#s!)w~f+C!IYobQqG~zy>i`8%u6BV!jn>_>FN(t$_ ztAO$iB?giv{QH%U_gv$7f;zidsVb!9x^`~>jhqp zm$DXFK34i@>vTtBmerV@-sSL1z-mI6Gbk)vYr+bGjRoty`fBrmD&%?mW6wUR^Ud!# z<&au(W_zkU(f1y)j40}NOW}0K<>amNCY)W(n$o?VPJ|}XbU?UDQ}@l(Rj8Sr+7Ka8 z=uX6w$d`iL)gJX;<8Lv8Y(YIr+|XSYe~L^GBB)*O1^J5bH8$AbBC+cJ6mLd;5%^7o z)2)*_3{jKdrX?w~*MEfFBn5v3MZ1m$UI8@7O=|m3Uc^s^G(!;aej_c`oRu^R9ZCJz zBdpdC^nv^c7v$_cUjNw1m-<(Da=CaOMeo71?o=P+j3aW$nPXc zoIrEm6rHV^W~{@vvu;gT7$1u92c2Sfnq*9pvu%-=jJ-8KM4!7W5)w#PeRXAfni|=T zFp}mmOw{`~QWM-8-(apk^CXvpd`$da)hB$!jf=lUD)zOJ{nqrDy~XIImydWDtVX$k z6|n^-*C)AI!|;9P@45)d6|WD(WWr+)H8JxVbH+Ucn|pi)vM@o-0yW|7T9W-YQ>!2> zZ@iPPN($M%_4nUu)4XuHU>~KJez}q?Om)*F{2_kwzAE{x&X8$~5OQ8lV=dfK{eW_~ zzvfcUL~;RiB&EAIT&pyE9k?2wZ}#3GHDihW$kqIZ3x{l75;n*UfkKM^v=p(A!zVH$Ce_Ri#Loeu;RKD? z?z6l$NLWxkBxx`g?F$N`rgfo#cd5nD?_`<2W7>9+cA!UWkKy%^dBF(uhdbW2Ln9_! z3`F8wO|r%a>9VjGcTeYp+Uc<0fKB|Y$Fn=UAtaIxB1$D2+2VA&#}AEkVUx zsqR$`u?R!(dTfdzZFrMk1HDNO@4KV}MEQYc(Z6(e_j{(4K@a$D!w(E^$7+Mm!kg5e zccyb};Kq#MNl^@P`JW@rd}on$;{ioSZuF}OCmA6JH#}rl<92hzhyk3OCA*B zJ$$tjXGNi~XRBxPz9ChP9T*M7qQ9CCP86rFK~X&MUR#~~Ev+{rrL=mwq_|G-znM)_ z&#VS9RPql$ZKzb2x)5)@!Ca7i|tshiRyW zwLW73h-ty%N}$Utf(tUk?dn5435j?~m{Xy`p+FFLl05;=;Fc1&dXVDfn5! zpbKHuxF8hSWTmwEsY2qu(gr*`Qs}^E-Qo=ctyuE3#go9@EGTQ&qv*= zm`&ha$OnbkhN47uYG2tGQNGVpvOgx?+<)vTzX;MwE1VEmo==!1#d^64{V6t}+-06` zp{e>s_=3M9=xW*Cc-f|m%{E+O3_V!VHA+`hEbNm=VyEI=A{Xa4Cfa6zir zk2~*ao5u))A|V-aMYTlwV_JRnzcMd_ngR3Jd7YFlRFW$4tDyPM0gv}2I6>c1ay)@w z4RxjeoDQ>F!7UQ}oUw~(*_NZ+`ty#V$3hC z{q3r(D2i{HgE-><6d2gD=zOOryB%pQ`6Bw=Zy#_mu1eRiD}r5!FiKvDruv-(JmTDR z_jcB^Bq;n+dGRTK1Tc&>()MlVG7*Yog|d`a&8Dt1M%5uFsne2`_qvzjY z$%LpXwTF~&krLo^K$U1xQ8j#-S2nU_WraV8|I0dlG@X|Re<|7!4GdHUk%&=+i`AQy z1FX~Ckw(Wb`lLpu*crTF0gA|VSik9W70{0HQ){f>L%YSGwdd+pqq<0XuJpNYX&q2> zj_IET?^A$Y%%o~)y$Vg9>Doy#m_~b&-ZF?>__T7HKR-+u$(My$%$Rt8hTlWYfC z&3EG{_-U&n!*iTh$dU4$vM%8U;9P*7h(Un^e2}j;BXCY4(NTwV_q7@%iX^(XVst4e z8tJg}h3UT3SU9%OP>d4f1wCRtXjRg( zgVE%N! zQ*1fD!{U9P8_SiFkbQUW9`_IOz8X-iiH#y( zb?%<5;QfsZRQ|2p8}l#ui;L{kt}TuQTPdc@j`>Rd>eBbyG5I{wTiK!<6uXan!+GCi z5U&aOw^~^3QCu_0#qssnH_lti&AKa^7f~HVs_mJf)9fZHr^#8b2@^+XvM}xKWt7sr zIy6mg;)b#R8Zp}J=^Un4zp}L@@d&y`%e(#;rGryE-sg0Z{3PbCqCzXGV1&$B3z44gd#e~E79j4bi z?$G&6|A7e`Nz%^vI@RS$1LSmS@4Tb;AK-ULfQWfcImFM8k_?JC4Yascq>zC`w}gUr7*1OpiE}Q1S?dR*u&8o!RzCJ4&YT^gQac!P1vGpRqYJQoM>qo z&~2G`ky$zLmt6tLBKEZcy+j|@$XwdE6f6%4dGDKRO&&qux~Dp%ghAXTBi)wM)bE^z z$#l23z_+j{(TeOY=x`BM&NXI}lxnwQV1Y52DA&iZxk4hxsw;dho`vkPn?>G*Jp@Kv zLnj_3MbdV+KQKcGFH!ZXq7^6Lq4XUaox$S(!y|g?H|II^M(dP`8$k_yUk$q@)yWXv zd=yd%klk^U8A&^7fgK_i08u{PTj^fO98{j2DRwbpj$Gh)c?iQaX4LoJvcFCqiL8^* z&Fz3b7C%xPiHjyiny+1|<3H|X{#rnd0NWG|7(FRq6A2%4TGeS1g$U^(W zi^$LHNIlK8XB^4d`@U|VgxBTCAoj~dhe12%&%n;$=9QcjF6B>ai^VNcUX-(VW9Ccv zM9rig34aJhR#Yi0h0d1btNw#|gf-EUsn~=zav#dRh&xC4WO%o(4Q|iaA@(uq5UJce ztVNm9#V2owhW{XQZFD<#P)Cy=txN{*1kO1u_q?Q4aEBLe2j&9?&c*$=>Bl)2=34#d zfjpP%gPMt~B>Pzje+lrQtJ+X-;?pFRSrvZ~P}%kG!Bx62$71e}{~(a%+|nOI*Weym z$O-%o>~oaq4W)JQdX}R@^nn>$>rN5s#nfl(SHd00ILpu$4kDUUg{@ z?zh^bLSZ<*MqqrAbU*yY>N0n6i;%ML)&vKfc6Db{?9xc$vT@6}YZ|+X79qbHBa z?9HZd)0Jd?qNOm^GL{)1Da77-H#cy^sjK5Y@-bUq=`xlY#}>Vi+D9`TJ?HxrSh2!4 z=)K;SEGeD!_rmkA8RA|&xa1EMQXgS;3yARrM^4gG`7b2yp`Ky#WytCN z)ZBrOUTr|3&GtqM_(?HEcbMoL<~?7>yuevo(~LfVPgHfx-Hk9iMO;0>QOAI`4A`om zUvr$G6!B$|&dN>FTI-Gi@Up7gb9e~1+r?cQOt(AVWDPr%?A8B9`Y-s<)R9Erlo8SS z_+6MbDGNb1O>+EKz-4lx5v^!JUmw% zZ}^&==XZN3m!gv`r}m4$2)i)jmngs=kF}Ys03ZCb+rljpsQ)(Nw4m z2cKy#bZZ0_&CgZH!J@6nE)w8`rEb*(bg%i6vktguHC@#MO*JVx8-SYDV^v!niRL@b z>LA)mrfLDY*;3)M41BP_RhYu8wlTL9VA`~{^akA0S?TEoL>Rp;GDpPwNWOc>Uv&j} zUKG#aOM#N4ZuMga2WB5}F#5vwwc*|ZY1JW{L=uVV(fROjN%0^$vUEiMZ?eKqO^=C937CYEon zyhM!m<8>**S8Hib9AMgjn>$4Hm>r5XL`(g_F$~jXHNwd}&y1FQNT(&pa~nBCLofXj zqx_Em-5oVKr9$0RAYrNxp0WK|u7U&)UkVMurHYlZt5d|4L`fFDk{#91o~HdChsacQ z`?QnHh0VVc2E*}Vj}Y^+TQ)}E856hsa_Ymm53yNspBE0(dUGX>l|didD`n4+nK5<* z1A|}R@+59Dbe?n@oT;Z3@#k7-J zu0W*o`RXKCb+|0-aYEFFD1(v7(5nQ7>?X@n5c$cPsNE=?B{D4{A7kbZikTPZ+tF$k z_n|&SB}`mK$Q6?g=Kyo+UQX~ujCRqgA5&~5$x+IR=yoWw6VzfIl1(}u2@t;|~z%Y?(;KO0WM zhNC~CzQl=4KS4UGV%@%z+~qJyYH2rhWx*>J>Upt=142anPrtfmBY0_u8>t%aGJOtd zsxEQgAWBMKWhrEyH3|R^8m(u%XFZm7C$+desOm(DX8EXq@Oym~(8tcEV5x+^mu4vX zM%rOPTT2DNJ{7q)6@_a|c&mrql=tKN8G z37eb0Xta{QosW?8B+VI@^4o07$lK-fSq|Aei!VUB3XYSS?rNd{+G)~w;(|j4USOUM zD`ho>r0Wab(MR!bC(CQwgJ%pc37+y^Y3hUj7#XCg^7=J2z)>UrG$o#trZ#xqs7`Q? zx1nhbx*7KK&n6dXdxIwW=TnDMw(0qTzB+G{XZWdxdf;UZkPD^PnPvmE${lR&?6+1D z09G!T3FJ@P7m&upY4qnMC9WNWe>c$7k}5GDY23+05V34J5@dzGHK_z=bkQRWV!Hd= zA^b>5+{sY!wjc1mx#@%}zCWs?sV torch.Tensor: + """Read a wave file and return it as a 1-D tensor. + + Note: + You don't need to scale it to [-32768, 32767]. + We use scaling here to follow the approach in Kaldi. + + Args: + filename: + Filename of a sound file. + Returns: + Return a 1-D tensor containing audio samples. + """ + with sf.SoundFile(filename) as sf_desc: + sampling_rate = sf_desc.samplerate + assert sampling_rate == 16000 + data = sf_desc.read(dtype=np.float32, always_2d=False) + data *= 32768 + return torch.from_numpy(data) + + +def test_fbank(): + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + wave0 = read_wave("test_data/test.wav") + wave1 = read_wave("test_data/test2.wav") + + wave0 = wave0.to(device) + wave1 = wave1.to(device) + + opts = kaldifeat.FbankOptions() + opts.frame_opts.dither = 0 + opts.device = device + + fbank = kaldifeat.Fbank(opts) + + # We can compute fbank features in batches + features = fbank([wave0, wave1]) + assert isinstance(features, list), f"{type(features)}" + assert len(features) == 2 + + # We can also compute fbank features for a single wave + features0 = fbank(wave0) + features1 = fbank(wave1) + + assert torch.allclose(features[0], features0) + assert torch.allclose(features[1], features1) + + # To compute fbank features for only a specified frame + audio_frames = fbank.convert_samples_to_frames(wave0) + feature_frame_1 = fbank.compute(audio_frames[1]) + feature_frame_10 = fbank.compute(audio_frames[10]) + + assert torch.allclose(features0[1], feature_frame_1) + assert torch.allclose(features0[10], feature_frame_10) + + +if __name__ == "__main__": + test_fbank() diff --git a/kaldifeat/python/tests/test_kaldifeat.py b/kaldifeat/python/tests/test_kaldifeat.py index c41712c..bcf14c8 100755 --- a/kaldifeat/python/tests/test_kaldifeat.py +++ b/kaldifeat/python/tests/test_kaldifeat.py @@ -52,7 +52,7 @@ def test_and_benchmark_default_parameters(): for device in devices: fbank_opts = _kaldifeat.FbankOptions() fbank_opts.frame_opts.dither = 0 - fbank_opts.set_device(device) + fbank_opts.device = device fbank = _kaldifeat.Fbank(fbank_opts) data = read_wave().to(device) @@ -74,14 +74,14 @@ def test_use_energy_htk_compat_true(): for device in devices: fbank_opts = _kaldifeat.FbankOptions() fbank_opts.frame_opts.dither = 0 - fbank_opts.set_device(device) + fbank_opts.device = device fbank_opts.use_energy = True fbank_opts.htk_compat = True fbank = _kaldifeat.Fbank(fbank_opts) data = read_wave().to(device) - ans = _kaldifeat.compute(data, fbank) + ans = _kaldifeat.compute_fbank_feats(data, fbank) expected = read_ark_txt("test-htk.txt") assert torch.allclose(ans.cpu(), expected, rtol=1e-2) @@ -97,12 +97,12 @@ def test_use_energy_htk_compat_false(): fbank_opts.frame_opts.dither = 0 fbank_opts.use_energy = True fbank_opts.htk_compat = False - fbank_opts.set_device(device) + fbank_opts.device = device fbank = _kaldifeat.Fbank(fbank_opts) data = read_wave().to(device) - ans = _kaldifeat.compute(data, fbank) + ans = _kaldifeat.compute_fbank_feats(data, fbank) expected = read_ark_txt("test-with-energy.txt") assert torch.allclose(ans.cpu(), expected, rtol=1e-2) @@ -117,12 +117,12 @@ def test_40_mel(): fbank_opts = _kaldifeat.FbankOptions() fbank_opts.frame_opts.dither = 0 fbank_opts.mel_opts.num_bins = 40 - fbank_opts.set_device(device) + fbank_opts.device = device fbank = _kaldifeat.Fbank(fbank_opts) data = read_wave().to(device) - ans = _kaldifeat.compute(data, fbank) + ans = _kaldifeat.compute_fbank_feats(data, fbank) expected = read_ark_txt("test-40.txt") assert torch.allclose(ans.cpu(), expected, rtol=1e-1) @@ -138,12 +138,12 @@ def test_40_mel_no_snip_edges(): fbank_opts.frame_opts.snip_edges = False fbank_opts.frame_opts.dither = 0 fbank_opts.mel_opts.num_bins = 40 - fbank_opts.set_device(device) + fbank_opts.device = device fbank = _kaldifeat.Fbank(fbank_opts) data = read_wave().to(device) - ans = _kaldifeat.compute(data, fbank) + ans = _kaldifeat.compute_fbank_feats(data, fbank) expected = read_ark_txt("test-40-no-snip-edges.txt") assert torch.allclose(ans.cpu(), expected, rtol=1e-2) @@ -161,7 +161,7 @@ def test_compute_batch(): fbank_opts = _kaldifeat.FbankOptions() fbank_opts.frame_opts.dither = 0 fbank_opts.frame_opts.snip_edges = False - fbank_opts.set_device(device) + fbank_opts.device = device fbank = _kaldifeat.Fbank(fbank_opts) def impl(waves: List[torch.Tensor]) -> List[torch.Tensor]: @@ -175,7 +175,9 @@ def test_compute_batch(): ] strided = torch.cat(strided, dim=0) - features = _kaldifeat.compute(strided, fbank).split(num_frames) + features = _kaldifeat.compute_fbank_feats(strided, fbank).split( + num_frames + ) return features diff --git a/kaldifeat/python/tests/test_options.py b/kaldifeat/python/tests/test_options.py index 41e660b..e4a01a7 100755 --- a/kaldifeat/python/tests/test_options.py +++ b/kaldifeat/python/tests/test_options.py @@ -52,6 +52,7 @@ def test_fbank_options(): opts.use_energy = False opts.use_log_fbank = True opts.use_power = True + opts.device = "cuda:0" frame_opts.blackman_coeff = 0.42 frame_opts.dither = 1 @@ -75,8 +76,8 @@ def test_fbank_options(): def main(): - # test_frame_extraction_options() - # test_mel_banks_options() + test_frame_extraction_options() + test_mel_banks_options() test_fbank_options()