from local

dohe0342 2023-02-02 13:58:45 +09:00
parent ffde762b9d
commit 37dc394116
2 changed files with 1 addition and 3 deletions


@@ -23,8 +23,6 @@ import torch.nn as nn
 from label_smoothing import LabelSmoothingLoss
 from subsampling import Conv2dSubsampling, VggSubsampling
 from torch.nn.utils.rnn import pad_sequence
-from torch.nn.modules import Module
-from torch import Tensor
 
 # Note: TorchScript requires Dict/List/etc. to be fully typed.
 Supervisions = Dict[str, torch.Tensor]
@@ -382,7 +380,7 @@ class Transformer(nn.Module):
         return nll
 
 
-class TransformerEncoder(Module):
+class TransformerEncoder(nn.TransformerEncoder):
     r"""TransformerEncoder is a stack of N encoder layers. Users can build the
     BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
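
The net effect of this change is that TransformerEncoder now derives from PyTorch's built-in nn.TransformerEncoder rather than a bare Module, which is also why the Module and Tensor imports become unnecessary. The sketch below is an illustrative, assumed example of the stock API the class now builds on; the MyTransformerEncoder name and the layer sizes are hypothetical and not taken from the repository.

import torch
import torch.nn as nn


class MyTransformerEncoder(nn.TransformerEncoder):
    """Illustrative subclass (hypothetical): inherits the stacked-layer
    forward pass from nn.TransformerEncoder and only customizes construction."""

    def __init__(self, d_model: int = 512, nhead: int = 8, num_layers: int = 6) -> None:
        # Build one encoder layer and let the base class replicate it num_layers times.
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
        super().__init__(layer, num_layers=num_layers)


encoder = MyTransformerEncoder()
src = torch.rand(10, 32, 512)  # (seq_len, batch, d_model); batch_first=False by default
out = encoder(src)             # inherited forward keeps the input shape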