from local

This commit is contained in:
dohe0342 2023-05-24 12:44:56 +09:00
parent 9c133b4516
commit b334bf5d0d
3 changed files with 3 additions and 3 deletions

View File

@ -165,10 +165,10 @@ class LoRAModule(nn.Module):
nn.init.zeros_(self.lora_B) nn.init.zeros_(self.lora_B)
nn.init_normal_(self.lora_A) nn.init_normal_(self.lora_A)
def forward(self, x, layer_idx=-1): def forward(self, x):
x = x.transpose(0, 1) x = x.transpose(0, 1)
x = self.lora_A[layer_idx](x) x = self.lora_A(x)
x = self.lora_B[layer_idx](x) x = self.lora_B(x)
x = x.transpose(0, 1) x = x.transpose(0, 1)
x *= self.scaling x *= self.scaling
return x return x