from local
This commit is contained in:
parent
fdaccdfaaa
commit
8bbdfc2ac1
Binary file not shown.
@ -380,7 +380,7 @@ class Transformer(nn.Module):
|
|||||||
return nll
|
return nll
|
||||||
|
|
||||||
|
|
||||||
class TransformerEncoder(nn.TransformerEncoder):
|
class TransfEncoder(nn.TransformerEncoder):
|
||||||
r"""TransformerEncoder is a stack of N encoder layers. Users can build the
|
r"""TransformerEncoder is a stack of N encoder layers. Users can build the
|
||||||
BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
|
BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
|
||||||
|
|
||||||
@ -401,7 +401,7 @@ class TransformerEncoder(nn.TransformerEncoder):
|
|||||||
__constants__ = ['norm']
|
__constants__ = ['norm']
|
||||||
|
|
||||||
def __init__(self, encoder_layer, num_layers, norm=None, enable_nested_tensor=True, mask_check=True):
|
def __init__(self, encoder_layer, num_layers, norm=None, enable_nested_tensor=True, mask_check=True):
|
||||||
super(TransformerEncoder, self).__init__()
|
super(TransfEncoder, self).__init__()
|
||||||
self.layers = _get_clones(encoder_layer, num_layers)
|
self.layers = _get_clones(encoder_layer, num_layers)
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
self.norm = norm
|
self.norm = norm
|
||||||
|
|||||||
Reference in New Issue
Block a user