from local

This commit is contained in:
dohe0342 2023-02-02 14:07:53 +09:00
parent 0f3696ed91
commit 89271fb9ea
2 changed files with 3 additions and 1 deletions

View File

@ -490,9 +490,11 @@ class TransfEncoder(nn.TransformerEncoder):
convert_to_nested = True
output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
src_key_padding_mask_for_layers = None
outputs = []
for mod in self.layers:
output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask_for_layers)
outputs.append(output)
if convert_to_nested:
output = output.to_padded_tensor(0.)