from local

This commit is contained in:
dohe0342 2023-02-14 14:58:54 +09:00
parent 4e57c16d28
commit 00207b7f77
3 changed files with 5 additions and 2 deletions

View File

@ -186,12 +186,15 @@ class Transformer(nn.Module):
encoder_memory, memory_key_padding_mask = self.run_encoder( encoder_memory, memory_key_padding_mask = self.run_encoder(
x, supervision, warmup x, supervision, warmup
) )
x = self.ctc_output(encoder_memory)
if type(encoder_memory) == tuple: if type(encoder_memory) == tuple:
(encoder_memory, layer_outputs) = encoder_memory (encoder_memory, layer_outputs) = encoder_memory
layer_outputs = [self.ctc_output(x) for x in layer_outputs] layer_outputs = [self.ctc_output(x) for x in layer_outputs]
x = self.ctc_output(encoder_memory)
return (x, layer_outputs), encoder_memory, memory_key_padding_mask return (x, layer_outputs), encoder_memory, memory_key_padding_mask
else:
return x, encoder_memory, memory_key_padding_mask
def run_encoder( def run_encoder(
self, self,