from local

This commit is contained in:
dohe0342 2023-02-14 14:42:54 +09:00
parent 17a39eb885
commit c19081f271
2 changed files with 3 additions and 1 deletions

Binary file not shown.

View File

@ -186,9 +186,11 @@ class Transformer(nn.Module):
encoder_memory, memory_key_padding_mask = self.run_encoder(
x, supervision, warmup
)
if type(encoder_memory) == tuple:
(encoder_memory, layer_outputs) = encoder_memory
x = self.ctc_output(encoder_memory)
return x, encoder_memory, memory_key_padding_mask
return (x, layer_outputs), encoder_memory, memory_key_padding_mask
def run_encoder(
self,