https://github.com/k2-fsa/icefall.git
commit 86ad1ce0eb (parent 64a4a7d6d7)
@@ -138,7 +138,6 @@ class LoRAModule(nn.Module):
     def __init__(
         self,
         embedding_dim: float = 768,
-        layer_num: int = 12,
         rank: int = 16,
         lora_alpha: int = 1,
         lora_dropout: float = 0.1,
@@ -153,10 +152,12 @@ class LoRAModule(nn.Module):
         else:
             self.lora_dropout = lambda x: x

-        self.lora_A = nn.ModuleList(
-            [nn.Linear(embedding_dim, self.r) for _ in range(layer_num)])
-        self.lora_B = nn.ModuleList(
-            [nn.Linear(self.r, embedding_dim) for _ in range(layer_num)])
+        #self.lora_A = nn.ModuleList(
+        #    [nn.Linear(embedding_dim, self.r) for _ in range(layer_num)])
+        #self.lora_B = nn.ModuleList(
+        #    [nn.Linear(self.r, embedding_dim) for _ in range(layer_num)])
+        self.lora_A = nn.Linear(embedding_dim, self.r)
+        self.lora_B = nn.Linear(self.r, embedding_dim)
         self.scaling = self.lora_alpha / self.r
         self.reset_parameters()

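The new attributes follow the standard LoRA formulation, where a frozen base projection is augmented by a low-rank update scaled by lora_alpha / r. The forward pass of LoRAModule is not part of this hunk, so the sketch below is only an assumption of how lora_A, lora_B, scaling, and lora_dropout are typically combined; the class name LoRALinearSketch, the bias-free adapter layers, and the frozen base wrapper are hypothetical and not taken from the commit.

# Hedged sketch, not from the commit: it assumes the standard LoRA update
#     h = base(x) + (lora_alpha / r) * lora_B(lora_A(dropout(x)))
# which the new lora_A / lora_B / scaling attributes suggest.
import torch
import torch.nn as nn


class LoRALinearSketch(nn.Module):
    def __init__(self, base: nn.Linear, r: int = 16, lora_alpha: int = 1,
                 lora_dropout: float = 0.1):
        super().__init__()
        self.base = base  # pretrained projection, kept frozen during fine-tuning
        for p in self.base.parameters():
            p.requires_grad = False
        self.lora_A = nn.Linear(base.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base.out_features, bias=False)
        self.lora_dropout = nn.Dropout(lora_dropout) if lora_dropout > 0.0 else nn.Identity()
        self.scaling = lora_alpha / r
        # Common LoRA initialization: A gets a Kaiming init, B starts at zero,
        # so the adapter contributes nothing before training.
        nn.init.kaiming_uniform_(self.lora_A.weight, a=5 ** 0.5)
        nn.init.zeros_(self.lora_B.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * self.lora_B(self.lora_A(self.lora_dropout(x)))


# Example: wrap a 768-dimensional projection and train only the adapter weights.
layer = LoRALinearSketch(nn.Linear(768, 768), r=16, lora_alpha=1)
out = layer(torch.randn(4, 768))  # shape (4, 768)

With the defaults shown in the diff (rank 16, lora_alpha 1) the scaling factor is 1/16, and initializing lora_B to zero keeps the adapted layer identical to the frozen base at the start of fine-tuning.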