from local

This commit is contained in:
dohe0342 2023-05-18 17:09:55 +09:00
parent 6cca73c4b4
commit 383cb553f4
2 changed files with 0 additions and 29 deletions

View File

@ -270,35 +270,6 @@ class LoRAModule(nn.Module):
x = x.transpose(0, 1)
return x
'''
self.type = 'linear'
def build_adapter(embedding_dim, proj_dim, type_=self.type):
if type_ == 'conv':
return ConvolutionModule(768, 31)
else:
return nn.Sequential(
#nn.LayerNorm(embedding_dim),
nn.Linear(embedding_dim, proj_dim),
nn.ReLU(),
nn.Linear(proj_dim, embedding_dim),
nn.LayerNorm(embedding_dim),
)
self.adapter_layers = nn.ModuleList(
[build_adapter(embedding_dim, proj_dim, type_=self.type) for _ in range(layer_num)]
)
def forward(self, x, layer_id=-1):
x = x.transpose(0, 1)
residual = x
x = self.adapter_layers[layer_id](x)
x = residual + x
x = x.transpose(0, 1)
return x
'''
class ResidualAdapterModule(nn.Module):
"""