from local

This commit is contained in:
dohe0342 2023-01-09 11:17:23 +09:00
parent 6a5099d9d2
commit 9924acb5a6
2 changed files with 9 additions and 7 deletions

View File

@ -211,7 +211,7 @@ class Conformer(EncoderInterface):
) # (T, N, C) ) # (T, N, C)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
'''
layer_output = [x.permute(1, 0, 2) for x in layer_output] layer_output = [x.permute(1, 0, 2) for x in layer_output]
x = self.layer_norm(1/12*(self.sigmoid(self.alpha[0])*layer_output[0] + \ x = self.layer_norm(1/12*(self.sigmoid(self.alpha[0])*layer_output[0] + \
@ -228,12 +228,14 @@ class Conformer(EncoderInterface):
self.sigmoid(self.alpha[11])*layer_output[11] self.sigmoid(self.alpha[11])*layer_output[11]
) )
) )
'''
#x = 0 layer_outputs = [x.permute(1, 0, 2) for x in layer_outputs]
#for enum, alpha in enumerate(self.alpha):
# x += self.sigmoid(alpha)*layer_output[enum] x = 0
for enum, alpha in enumerate(self.alpha):
#x = self.layer_norm((1/self.group_size)*x) x += self.sigmoid(alpha*layer_outputs[(enum+1)*self.group_layer_num-1])
x = self.layer_norm(x/self.group_num)
return x, lengths return x, lengths