From 6433d0e8014356fede81c9b4bfeb291c6cca8446 Mon Sep 17 00:00:00 2001
From: dohe0342
Date: Sat, 28 Jan 2023 02:37:56 +0900
Subject: [PATCH] from local

---
 .../.conformer_randomcombine.py.swp          | Bin 98304 -> 98304 bytes
 .../conformer_randomcombine.py               |  45 +-----
 2 files changed, 1 insertion(+), 44 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/.conformer_randomcombine.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/.conformer_randomcombine.py.swp
index a3f55c1e9f271cd0f05f4b3d11cb31d9ff7e3e47..ce2b4814b3030528bfd8d6b5cc37fd22ae7b7c46 100644
Binary files a/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/.conformer_randomcombine.py.swp and b/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/.conformer_randomcombine.py.swp differ
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/conformer_randomcombine.py b/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/conformer_randomcombine.py
index 85e25cda3..8c4e5f94b 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/conformer_randomcombine.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/conformer_randomcombine.py
@@ -213,56 +213,13 @@ class Conformer(EncoderInterface):
         x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
 
         layer_outputs = [x.permute(1, 0, 2) for x in layer_outputs]
-        '''
-        if self.group_num == 4:
-            x = self.layer_norm(1/4*(self.sigmoid(self.alpha[0])*layer_outputs[2] + \
-                                     self.sigmoid(self.alpha[1])*layer_outputs[5] + \
-                                     self.sigmoid(self.alpha[2])*layer_outputs[8] + \
-                                     self.sigmoid(self.alpha[3])*layer_outputs[11]
-                                     )
-                                )
-        elif self.group_num == 6:
-            x = self.layer_norm(1/6*(self.sigmoid(self.alpha[0])*layer_outputs[1] + \
-                                     self.sigmoid(self.alpha[1])*layer_outputs[3] + \
-                                     self.sigmoid(self.alpha[2])*layer_outputs[5] + \
-                                     self.sigmoid(self.alpha[3])*layer_outputs[7] + \
-                                     self.sigmoid(self.alpha[4])*layer_outputs[9] + \
-                                     self.sigmoid(self.alpha[5])*layer_outputs[11]
-                                     )
-                                )
-
-        elif self.group_num == 12:
-            x = self.layer_norm(1/12*(self.sigmoid(self.alpha[0])*layer_outputs[0] + \
-                                      self.sigmoid(self.alpha[1])*layer_outputs[1] + \
-                                      self.sigmoid(self.alpha[2])*layer_outputs[2] + \
-                                      self.sigmoid(self.alpha[3])*layer_outputs[3] + \
-                                      self.sigmoid(self.alpha[4])*layer_outputs[4] + \
-                                      self.sigmoid(self.alpha[5])*layer_outputs[5] + \
-                                      self.sigmoid(self.alpha[6])*layer_outputs[6] + \
-                                      self.sigmoid(self.alpha[7])*layer_outputs[7] + \
-                                      self.sigmoid(self.alpha[8])*layer_outputs[8] + \
-                                      self.sigmoid(self.alpha[9])*layer_outputs[9] + \
-                                      self.sigmoid(self.alpha[10])*layer_outputs[10] + \
-                                      self.sigmoid(self.alpha[11])*layer_outputs[11]
-                                      )
-                                )
-        '''
-
+
         if self.group_num != 0:
             x = 0
             for enum, alpha in enumerate(self.alpha):
                 x += self.sigmoid(alpha) * layer_outputs[(enum+1)*self.group_layer_num-1]
 
             x = self.layer_norm(x/self.group_num)
-        '''
-        layer_outputs = [x.permute(1, 0, 2) for x in layer_outputs]
-
-        x = 0
-        for enum, alpha in enumerate(self.alpha):
-            x += self.sigmoid(alpha*layer_outputs[(enum+1)*self.group_layer_num-1])
-
-        x = self.layer_norm(x/self.group_num)
-        '''
         return x, lengths
 
     @torch.jit.export
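For readability, the combination performed by the retained (uncommented) block can be restated as a standalone sketch. The helper name `combine_layer_outputs` and the example shapes below are illustrative only and are not part of the patch; the sketch assumes the encoder's layer outputs are evenly split into `group_num` groups, as in the removed hard-coded branches.

    import torch

    def combine_layer_outputs(layer_outputs, alpha, group_num, layer_norm):
        # Weight the last output of each group by sigmoid(alpha_i),
        # average over the groups, then apply layer norm -- the same
        # computation as the loop kept by this patch.
        group_layer_num = len(layer_outputs) // group_num
        x = 0
        for enum, a in enumerate(alpha):
            x = x + torch.sigmoid(a) * layer_outputs[(enum + 1) * group_layer_num - 1]
        return layer_norm(x / group_num)

    # Hypothetical usage with (N, T, C) tensors from a 12-layer encoder:
    N, T, C, num_layers, group_num = 2, 50, 256, 12, 4
    layer_outputs = [torch.randn(N, T, C) for _ in range(num_layers)]
    alpha = torch.nn.Parameter(torch.zeros(group_num))
    layer_norm = torch.nn.LayerNorm(C)
    x = combine_layer_outputs(layer_outputs, alpha, group_num, layer_norm)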