from local

This commit is contained in:
dohe0342 2023-01-09 19:30:29 +09:00
parent 1994866c0a
commit dc641991a5
2 changed files with 5 additions and 1 deletions

View File

@ -243,7 +243,11 @@ class Interformer(nn.Module):
x: torch.Tensor,
x_lens: torch.Tensor,
):
encoder_out, x_lens = self.pt_encoder(x, x_lens, warmup=warmup)
encoder_out, x_lens, layer_outputs = self.pt_encoder(x,
x_lens,
warmup=warmup,
get_layer_output=True
)
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network