Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)
also add lora in SelfAttention (for the value proj)
This commit is contained in:
parent 9bc1ad87b4
commit 5272a71ec9
@@ -634,9 +634,23 @@ class Zipformer2EncoderLayer(nn.Module):
             lora_dropout=lora_dropout,
         )
 
-        self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim)
+        self.self_attn1 = SelfAttention(
+            embed_dim,
+            num_heads,
+            value_head_dim,
+            lora_r=lora_r,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+        )
 
-        self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim)
+        self.self_attn2 = SelfAttention(
+            embed_dim,
+            num_heads,
+            value_head_dim,
+            lora_r=lora_r,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+        )
 
         self.feed_forward1 = FeedforwardModule(
             embed_dim, (feedforward_dim * 3) // 4, dropout
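For orientation, the call sites above simply forward the encoder layer's LoRA hyperparameters into each SelfAttention instance. A hypothetical construction follows; the numeric values are illustrative only, not the recipe's defaults:

attn = SelfAttention(
    embed_dim=384,
    num_heads=8,
    value_head_dim=12,
    lora_r=8,          # rank of the low-rank update; 0 disables LoRA
    lora_alpha=4,      # update is scaled by lora_alpha / lora_r
    lora_dropout=0.1,  # dropout applied on the adapter input
)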
@@ -1901,9 +1915,19 @@ class SelfAttention(nn.Module):
         embed_dim: int,
         num_heads: int,
         value_head_dim: int,
+        lora_r: int = 0,
+        lora_alpha: int = 4,
+        lora_dropout: float = 0.0,
     ) -> None:
         super().__init__()
-        self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True)
+        self.in_proj = ScaledLinear_lora(
+            in_features=embed_dim,
+            out_features=num_heads * value_head_dim,
+            r=lora_r,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+            bias=True,
+        )
 
         self.out_proj = ScaledLinear(
             num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05
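The ScaledLinear_lora module used for the value projection is defined elsewhere in the recipe and is not shown in this diff. For orientation only, here is a minimal sketch of a LoRA-style linear layer under the standard formulation from Hu et al. (2021), y = W x + (lora_alpha / r) * B A x, with B zero-initialized so the adapter starts as a no-op. The class name LoRALinearSketch and every implementation detail below are assumptions, not the actual icefall code:

import math

import torch
import torch.nn as nn


class LoRALinearSketch(nn.Module):
    """Hypothetical stand-in for ScaledLinear_lora: y = W x + (alpha / r) * B A x."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 4,
        lora_dropout: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.r = r
        if r > 0:
            # Low-rank factors: A maps down to rank r, B maps back up.
            self.lora_A = nn.Parameter(torch.zeros(r, in_features))
            self.lora_B = nn.Parameter(torch.zeros(out_features, r))
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            # lora_B stays zero at init, so the adapter contributes nothing at first.
            self.scaling = lora_alpha / r
            self.dropout = nn.Dropout(lora_dropout)
            # Freeze the base weight; only the low-rank adapters are trained.
            self.weight.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = nn.functional.linear(x, self.weight, self.bias)
        if self.r > 0:
            y = y + self.dropout(x) @ self.lora_A.t() @ self.lora_B.t() * self.scaling
        return y

With r=0, matching the new lora_r=0 default in SelfAttention, the adapter branch is skipped entirely, so pre-existing configurations keep their original behavior.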