from local

dohe0342 2023-02-02 13:48:52 +09:00
parent f8af9f0400
commit c875d7c1c1
2 changed files with 124 additions and 0 deletions

@@ -381,6 +381,130 @@ class Transformer(nn.Module):
        return nll
class TransformerEncoder(Module):
r"""TransformerEncoder is a stack of N encoder layers. Users can build the
BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
Args:
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
num_layers: the number of sub-encoder-layers in the encoder (required).
norm: the layer normalization component (optional).
enable_nested_tensor: if True, input will automatically convert to nested tensor
(and convert back on output). This will improve the overall performance of
TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
Examples::
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
>>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
>>> src = torch.rand(10, 32, 512)
>>> out = transformer_encoder(src)
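        >>> # An illustrative padded-batch call (shapes and the all-False mask are
        >>> # arbitrary here): a boolean key-padding mask of shape (batch, seq_len)
        >>> # can let the encoder convert the input to a nested tensor internally
        >>> # when enable_nested_tensor is True.
        >>> src_key_padding_mask = torch.zeros(32, 10, dtype=torch.bool)
        >>> out = transformer_encoder(src, src_key_padding_mask=src_key_padding_mask)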
"""
__constants__ = ['norm']
def __init__(self, encoder_layer, num_layers, norm=None, enable_nested_tensor=True, mask_check=True):
super(TransformerEncoder, self).__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
self.enable_nested_tensor = enable_nested_tensor
self.mask_check = mask_check
[docs] def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
see the docs in Transformer class.
"""
        if src_key_padding_mask is not None:
            _skpm_dtype = src_key_padding_mask.dtype
            if _skpm_dtype != torch.bool and not torch.is_floating_point(src_key_padding_mask):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported")
        output = src
        convert_to_nested = False
        first_layer = self.layers[0]
        src_key_padding_mask_for_layers = src_key_padding_mask
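
        # The elif chain below decides whether the fused nested-tensor
        # ("sparsity") fast path can be used; the first failing condition is
        # recorded in why_not_sparsity_fast_path as a diagnostic.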
        why_not_sparsity_fast_path = ''
        str_first_layer = "self.layers[0]"
        if not isinstance(first_layer, torch.nn.TransformerEncoderLayer):
            why_not_sparsity_fast_path = f"{str_first_layer} was not TransformerEncoderLayer"
        elif first_layer.norm_first:
            why_not_sparsity_fast_path = f"{str_first_layer}.norm_first was True"
        elif first_layer.training:
            why_not_sparsity_fast_path = f"{str_first_layer} was in training mode"
        elif not first_layer.self_attn.batch_first:
            why_not_sparsity_fast_path = f"{str_first_layer}.self_attn.batch_first was not True"
        elif not first_layer.self_attn._qkv_same_embed_dim:
            why_not_sparsity_fast_path = f"{str_first_layer}.self_attn._qkv_same_embed_dim was not True"
        elif not first_layer.activation_relu_or_gelu:
            why_not_sparsity_fast_path = f"{str_first_layer}.activation_relu_or_gelu was not True"
        elif not (first_layer.norm1.eps == first_layer.norm2.eps):
            why_not_sparsity_fast_path = f"{str_first_layer}.norm1.eps was not equal to {str_first_layer}.norm2.eps"
        elif not src.dim() == 3:
            why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
        elif not self.enable_nested_tensor:
            why_not_sparsity_fast_path = "enable_nested_tensor was not True"
        elif src_key_padding_mask is None:
            why_not_sparsity_fast_path = "src_key_padding_mask was None"
        elif (((not hasattr(self, "mask_check")) or self.mask_check)
                and not torch._nested_tensor_from_mask_left_aligned(src, src_key_padding_mask.logical_not())):
            why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask was not left aligned"
        elif output.is_nested:
            why_not_sparsity_fast_path = "NestedTensor input is not supported"
        elif mask is not None:
            why_not_sparsity_fast_path = "src_key_padding_mask and mask were both supplied"
        elif first_layer.self_attn.num_heads % 2 == 1:
            why_not_sparsity_fast_path = "num_head is odd"
        elif torch.is_autocast_enabled():
            why_not_sparsity_fast_path = "autocast is enabled"
        if not why_not_sparsity_fast_path:
            tensor_args = (
                src,
                first_layer.self_attn.in_proj_weight,
                first_layer.self_attn.in_proj_bias,
                first_layer.self_attn.out_proj.weight,
                first_layer.self_attn.out_proj.bias,
                first_layer.norm1.weight,
                first_layer.norm1.bias,
                first_layer.norm2.weight,
                first_layer.norm2.bias,
                first_layer.linear1.weight,
                first_layer.linear1.bias,
                first_layer.linear2.weight,
                first_layer.linear2.bias,
            )

            if torch.overrides.has_torch_function(tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
            elif not (src.is_cuda or 'cpu' in str(src.device)):
                why_not_sparsity_fast_path = "src is neither CUDA nor CPU"
            elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
                why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
                                              "input/output projection weights or biases requires_grad")
            if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None):
                convert_to_nested = True
                output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
                src_key_padding_mask_for_layers = None
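
        # Run the encoder stack: each layer sees either the padded tensor with
        # its explicit masks, or the nested tensor with no per-layer padding mask.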
        for mod in self.layers:
            output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask_for_layers)

        if convert_to_nested:
            output = output.to_padded_tensor(0.)

        if self.norm is not None:
            output = self.norm(output)

        return output

class TransformerEncoderLayer(nn.Module):
    """
    Modified from torch.nn.TransformerEncoderLayer.