from local

This commit is contained in:
dohe0342 2023-01-09 11:31:27 +09:00
parent a4375ac7d4
commit 40a3e35eb7
2 changed files with 17 additions and 6 deletions

View File

@ -211,14 +211,25 @@ class Conformer(EncoderInterface):
) # (T, N, C)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
layer_outputs = [x.permute(1, 0, 2) for x in layer_outputs]
x = self.layer_norm(1/4*(self.sigmoid(self.alpha[0])*layer_outputs[2] + \
self.sigmoid(self.alpha[1])*layer_outputs[5] + \
self.sigmoid(self.alpha[2])*layer_outputs[8] + \
self.sigmoid(self.alpha[3])*layer_outputs[11]
if self.group_num == 4:
x = self.layer_norm(1/4*(self.sigmoid(self.alpha[0])*layer_outputs[2] + \
self.sigmoid(self.alpha[1])*layer_outputs[5] + \
self.sigmoid(self.alpha[2])*layer_outputs[8] + \
self.sigmoid(self.alpha[3])*layer_outputs[11]
)
)
)
elif self.group_num == 6:
x = self.layer_norm(1/6*(self.sigmoid(self.alpha[0])*layer_outputs[1] + \
self.sigmoid(self.alpha[1])*layer_outputs[3] + \
self.sigmoid(self.alpha[2])*layer_outputs[5] + \
self.sigmoid(self.alpha[3])*layer_outputs[7] + \
self.sigmoid(self.alpha[4])*layer_outputs[9] + \
self.sigmoid(self.alpha[5])*layer_outputs[11]
)
)
'''
x = self.layer_norm(1/12*(self.sigmoid(self.alpha[0])*layer_output[0] + \