from local

This commit is contained in:
dohe0342 2023-05-18 17:07:42 +09:00
parent d2dffc3bb9
commit aecaf1cedc
2 changed files with 12 additions and 2 deletions

View File

@ -252,9 +252,19 @@ class LoRAModule(nn.Module):
else: else:
self.lora_dropout = lambda x: x self.lora_dropout = lambda x: x
self.lora_A = nn.Linear(embedding_dim, self.r) self.lora_A = nn.ModuleList(
self.lora_B = nn.Linear(self.r, embedding_dim) [nn.Linear(embedding_dim, self.r) for _ in range(layer_num)])
self.lora_B = nn.ModuleList(
[nn.Linear(self.r, embedding_dim) for _ in range(layer_num)])
self.scaling = self.lora_alpha / self.r self.scaling = self.lora_alpha / self.r
self.reset_parameters()
def reset_parameters(self):
    """Reset LoRA parameters so the adapter initially contributes zero.

    Every ``lora_B`` projection is zeroed (weight and, when present, bias)
    so that ``B(A(x)) == 0`` at the start of training, while each
    ``lora_A`` weight is drawn from a standard normal — presumably the
    standard LoRA init scheme (A ~ N(0, 1), B = 0); confirm against the
    training recipe.
    """
    # NOTE(review): the original called nn.init.zeros_/nn.init_normal_
    # directly on the ModuleLists, which raises: `nn.init_normal_` does not
    # exist (it is `nn.init.normal_`), and torch.nn.init functions operate
    # on tensors, not modules. Iterate the per-layer Linears instead.
    for proj in self.lora_A:
        nn.init.normal_(proj.weight)
    for proj in self.lora_B:
        nn.init.zeros_(proj.weight)
        if proj.bias is not None:
            nn.init.zeros_(proj.bias)
def forward(self, x, layer_id=-1):
''' '''