from local

This commit is contained in:
dohe0342 2023-01-09 20:21:47 +09:00
parent 8b59cb1ac0
commit cc5c3eff53
3 changed files with 6 additions and 5 deletions

View File

@ -244,11 +244,12 @@ class Interformer(nn.Module):
x: torch.Tensor, x: torch.Tensor,
x_lens: torch.Tensor, x_lens: torch.Tensor,
): ):
encoder_out, x_lens, pt_layer_outputs = self.pt_encoder(x, with torch.no_grad():
x_lens, encoder_out, x_lens, pt_layer_outputs = self.pt_encoder(x,
warmup=warmup, x_lens,
get_layer_output=True warmup=warmup,
) get_layer_output=True
)
inter_layer_outputs = self.inter_encoder(pt_layer_outputs) inter_layer_outputs = self.inter_encoder(pt_layer_outputs)
loss = 0 loss = 0