from local

This commit is contained in:
dohe0342 2023-02-14 18:40:32 +09:00
parent fab9762d31
commit 3b48f52b71
3 changed files with 1 additions and 2 deletions

View File

@ -186,14 +186,13 @@ class Transformer(nn.Module):
encoder_memory, memory_key_padding_mask = self.run_encoder(
x, supervision, warmup
)
x = self.ctc_output(encoder_memory)
if type(encoder_memory) == tuple:
(encoder_memory, layer_outputs) = encoder_memory
layer_outputs = [self.ctc_output(x) for x in layer_outputs]
return (x, layer_outputs), encoder_memory, memory_key_padding_mask
else:
x = self.ctc_output(encoder_memory)
return x, encoder_memory, memory_key_padding_mask
def run_encoder(