from local

This commit is contained in:
dohe0342 2023-02-02 14:09:13 +09:00
parent 8a36b40a12
commit 8031ade3b5
2 changed files with 3 additions and 2 deletions

View File

@ -113,6 +113,7 @@ class Transformer(nn.Module):
# num_layers=num_encoder_layers, # num_layers=num_encoder_layers,
# norm=encoder_norm, # norm=encoder_norm,
#) #)
print(encoder_norm)
self.encoder = TransfEncoder( self.encoder = TransfEncoder(
encoder_layer=encoder_layer, encoder_layer=encoder_layer,
num_layers=num_encoder_layers, num_layers=num_encoder_layers,
@ -491,10 +492,10 @@ class TransfEncoder(nn.TransformerEncoder):
output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False) output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
src_key_padding_mask_for_layers = None src_key_padding_mask_for_layers = None
outputs = [] layer_outputs = []
for mod in self.layers: for mod in self.layers:
output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask_for_layers) output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask_for_layers)
outputs.append(output) layer_outputs.append(output)
if convert_to_nested: if convert_to_nested:
output = output.to_padded_tensor(0.) output = output.to_padded_tensor(0.)