diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/.conformer_randomcombine.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/.conformer_randomcombine.py.swp index 1423504dc..18599bded 100644 Binary files a/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/.conformer_randomcombine.py.swp and b/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/.conformer_randomcombine.py.swp differ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/conformer_randomcombine.py b/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/conformer_randomcombine.py index defbdcb6e..83846158a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/conformer_randomcombine.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/conformer_randomcombine.py @@ -211,7 +211,7 @@ class Conformer(EncoderInterface): ) # (T, N, C) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - + ''' layer_output = [x.permute(1, 0, 2) for x in layer_output] x = self.layer_norm(1/12*(self.sigmoid(self.alpha[0])*layer_output[0] + \ @@ -228,12 +228,14 @@ class Conformer(EncoderInterface): self.sigmoid(self.alpha[11])*layer_output[11] ) ) - - #x = 0 - #for enum, alpha in enumerate(self.alpha): - # x += self.sigmoid(alpha)*layer_output[enum] - - #x = self.layer_norm((1/self.group_size)*x) + ''' + layer_outputs = [x.permute(1, 0, 2) for x in layer_outputs] + + x = 0 + for enum, alpha in enumerate(self.alpha): + x += self.sigmoid(alpha*layer_outputs[(enum+1)*self.group_layer_num-1]) + + x = self.layer_norm(x/self.group_num) return x, lengths