from local

This commit is contained in:
dohe0342 2023-05-24 12:03:42 +09:00
parent 64a4a7d6d7
commit 86ad1ce0eb
3 changed files with 6 additions and 5 deletions

View File

@ -138,7 +138,6 @@ class LoRAModule(nn.Module):
def __init__( def __init__(
self, self,
embedding_dim: float = 768, embedding_dim: float = 768,
layer_num: int = 12,
rank: int = 16, rank: int = 16,
lora_alpha: int = 1, lora_alpha: int = 1,
lora_dropout: float = 0.1, lora_dropout: float = 0.1,
@ -153,10 +152,12 @@ class LoRAModule(nn.Module):
else: else:
self.lora_dropout = lambda x: x self.lora_dropout = lambda x: x
self.lora_A = nn.ModuleList( #self.lora_A = nn.ModuleList(
[nn.Linear(embedding_dim, self.r) for _ in range(layer_num)]) # [nn.Linear(embedding_dim, self.r) for _ in range(layer_num)])
self.lora_B = nn.ModuleList( #self.lora_B = nn.ModuleList(
[nn.Linear(self.r, embedding_dim) for _ in range(layer_num)]) # [nn.Linear(self.r, embedding_dim) for _ in range(layer_num)])
self.lora_A = nn.Linear(embedding_dim, self.r)
self.lora_B = nn.Linear(self.r, embedding_dim)
self.scaling = self.lora_alpha / self.r self.scaling = self.lora_alpha / self.r
self.reset_parameters() self.reset_parameters()