from local

This commit is contained in:
dohe0342 2023-01-03 13:41:44 +09:00
parent b21111b67f
commit 0207b96426
2 changed files with 4 additions and 2 deletions

View File

@ -145,9 +145,11 @@ class ResidualAdapterModule(nn.Module):
super().__init__()
self.adapters = ConvolutionModule(768, 31)
def build_adapter(embedding_dim, proj_dim, type_='conv'):
return nn.Sequential(
if type_ == 'conv':
return ConvolutionModule(768, 31)
else:
return nn.Sequential(
#nn.LayerNorm(embedding_dim),
nn.Linear(embedding_dim, proj_dim),
nn.ReLU(),