from local

This commit is contained in:
dohe0342 2023-05-18 17:04:00 +09:00
parent fbb9649217
commit 183bb88394
2 changed files with 16 additions and 1 deletion


@@ -239,10 +239,25 @@ class LoRAModule(nn.Module):
        embedding_dim: int = 768,
        layer_num: int = 12,
        rank: int = 16,
        lora_alpha: int = 1,
        lora_dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.r = rank
        self.lora_alpha = lora_alpha
        # Optional dropout
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        # Low-rank adapter pair: project down to rank, then back up to embedding_dim
        self.lora_A = nn.Linear(embedding_dim, rank)
        self.lora_B = nn.Linear(rank, embedding_dim)
        # LoRA update is scaled by alpha / r
        self.scaling = self.lora_alpha / self.r
        '''
        self.type = 'linear'
        def build_adapter(embedding_dim, proj_dim, type_=self.type):
@@ -269,7 +284,7 @@ class LoRAModule(nn.Module):
            x = x.transpose(0, 1)
            return x
        '''
class ResidualAdapterModule(nn.Module):
"""