from local

This commit is contained in:
dohe0342 2023-02-02 13:59:14 +09:00
parent fdaccdfaaa
commit 8bbdfc2ac1
2 changed files with 2 additions and 2 deletions

View File

@ -380,7 +380,7 @@ class Transformer(nn.Module):
return nll
class TransformerEncoder(nn.TransformerEncoder):
class TransfEncoder(nn.TransformerEncoder):
r"""TransformerEncoder is a stack of N encoder layers. Users can build the
BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
@ -401,7 +401,7 @@ class TransformerEncoder(nn.TransformerEncoder):
__constants__ = ['norm']
def __init__(self, encoder_layer, num_layers, norm=None, enable_nested_tensor=True, mask_check=True):
super(TransformerEncoder, self).__init__()
super(TransfEncoder, self).__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm