from local

dohe0342 2023-05-24 12:03:42 +09:00
parent 64a4a7d6d7
commit 86ad1ce0eb
3 changed files with 6 additions and 5 deletions


@@ -138,7 +138,6 @@ class LoRAModule(nn.Module):
     def __init__(
         self,
         embedding_dim: float = 768,
-        layer_num: int = 12,
         rank: int = 16,
         lora_alpha: int = 1,
         lora_dropout: float = 0.1,
@@ -153,10 +152,12 @@ class LoRAModule(nn.Module):
         else:
             self.lora_dropout = lambda x: x
-        self.lora_A = nn.ModuleList(
-            [nn.Linear(embedding_dim, self.r) for _ in range(layer_num)])
-        self.lora_B = nn.ModuleList(
-            [nn.Linear(self.r, embedding_dim) for _ in range(layer_num)])
+        #self.lora_A = nn.ModuleList(
+        #    [nn.Linear(embedding_dim, self.r) for _ in range(layer_num)])
+        #self.lora_B = nn.ModuleList(
+        #    [nn.Linear(self.r, embedding_dim) for _ in range(layer_num)])
+        self.lora_A = nn.Linear(embedding_dim, self.r)
+        self.lora_B = nn.Linear(self.r, embedding_dim)
         self.scaling = self.lora_alpha / self.r
         self.reset_parameters()
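
For context, the change drops the per-layer `layer_num` parameter, so a LoRAModule now holds a single lora_A/lora_B projection pair instead of one pair per transformer layer. Below is a minimal sketch of how such a module could be wired up and used in a forward pass, assuming the standard LoRA update x + B(A(dropout(x))) * (alpha / r); the simplified constructor and the forward method are illustrative and are not part of this diff.

import torch
import torch.nn as nn

class LoRAModuleSketch(nn.Module):
    # Simplified mirror of the fields visible in the diff (hypothetical names).
    def __init__(self, embedding_dim: int = 768, rank: int = 16,
                 lora_alpha: int = 1, lora_dropout: float = 0.1):
        super().__init__()
        self.r = rank
        self.lora_alpha = lora_alpha
        # Dropout applied to the adapter input, identity if disabled.
        self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0. else nn.Identity()
        # Single down/up projection pair shared by this module (post-change layout).
        self.lora_A = nn.Linear(embedding_dim, self.r)
        self.lora_B = nn.Linear(self.r, embedding_dim)
        self.scaling = self.lora_alpha / self.r

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Standard LoRA residual update, scaled by alpha / r.
        return hidden_states + self.lora_B(self.lora_A(self.lora_dropout(hidden_states))) * self.scaling

# Example usage (shapes are illustrative):
# x = torch.randn(4, 10, 768)
# out = LoRAModuleSketch()(x)   # same shape as x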