add lora version of ActivationDropoutAndLinear; currently a simple version

marcoyang 2024-03-11 12:02:35 +08:00
parent bb8f6b0ef7
commit 3492d9415c
2 changed files with 44 additions and 2 deletions


@@ -1740,6 +1740,45 @@ class ActivationDropoutAndLinear(torch.nn.Module):
            self.dropout_shared_dim,
        )


class ActivationDropoutAndLinear_lora(torch.nn.Module):
    """A simple LoRA variant of ActivationDropoutAndLinear: applies the
    activation, then the LoRA linear projection (ScaledLinear_lora), then dropout."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        bias: bool = True,
        activation: str = "SwooshL",
        dropout_p: FloatLike = 0.0,
        dropout_shared_dim: Optional[int] = -1,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        initial_scale: float = 1.0,
    ):
        super().__init__()
        # Unlike ActivationDropoutAndLinear, which steals the weight and bias
        # from a temporary ScaledLinear, keep the LoRA-enabled linear layer as
        # a submodule so its low-rank parameters remain trainable.
        self.l = ScaledLinear_lora(
            in_features=in_channels,
            out_features=out_channels,
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            initial_scale=initial_scale,
            bias=bias,
        )
        self.activation = activation
        self.dropout = Dropout3(dropout_p, dropout_shared_dim)

    def forward(self, x: Tensor):
        if self.activation == "SwooshL":
            x = SwooshLForward(x)
        elif self.activation == "SwooshR":
            x = SwooshRForward(x)
        else:
            assert False, self.activation
        return self.dropout(self.l(x))
def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
    if num_channels <= x.shape[-1]:
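The diff only shows how ScaledLinear_lora is consumed, not its internals. Its hyperparameters (r, lora_alpha, lora_dropout) match the standard LoRA formulation, in which a frozen weight is augmented by a trainable low-rank update B @ A scaled by lora_alpha / r. The following is a minimal, self-contained sketch of that formulation only; LoRALinearSketch and all dimensions are invented for illustration and are not the repository's implementation.

    import torch
    import torch.nn as nn


    class LoRALinearSketch(nn.Module):
        # y = base(x) + (lora_alpha / r) * dropout(x) @ A.T @ B.T
        # base is frozen; only A (r x in) and B (out x r) are trained.
        def __init__(self, in_features: int, out_features: int,
                     r: int = 4, lora_alpha: int = 1, lora_dropout: float = 0.0):
            super().__init__()
            self.base = nn.Linear(in_features, out_features)
            for p in self.base.parameters():
                p.requires_grad_(False)  # pretrained weights stay fixed
            self.lora_A = nn.Parameter(torch.zeros(r, in_features))
            self.lora_B = nn.Parameter(torch.zeros(out_features, r))
            nn.init.kaiming_uniform_(self.lora_A, a=5**0.5)  # B stays zero, so the
            self.scaling = lora_alpha / r                    # update starts at zero
            self.dropout = nn.Dropout(lora_dropout)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            delta = self.dropout(x) @ self.lora_A.T @ self.lora_B.T
            return self.base(x) + self.scaling * delta


    x = torch.randn(2, 100, 192)                       # (batch, time, channels)
    y = LoRALinearSketch(192, 768, r=8, lora_alpha=16)(x)
    print(y.shape)                                     # torch.Size([2, 100, 768])

Because lora_B starts at zero, the adapted layer initially reproduces the frozen layer exactly, which is what makes a LoRA layer safe to bolt onto a pretrained model.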


@@ -34,6 +34,7 @@ from scaling import (
)
from scaling import (
    ActivationDropoutAndLinear,
    ActivationDropoutAndLinear_lora,
    Balancer,
    BiasNorm,
    ChunkCausalDepthwiseConv1d,
@@ -2066,7 +2067,6 @@ class FeedforwardModule(nn.Module):
        lora_dropout: float = 0.0
    ):
        super(FeedforwardModule, self).__init__()
        # self.in_proj = nn.Linear(embed_dim, feedforward_dim)
        self.in_proj = ScaledLinear_lora(
            in_features=embed_dim,
            out_features=feedforward_dim,
@@ -2086,13 +2086,16 @@ class FeedforwardModule(nn.Module):
        )
        # shared_dim=0 means we share the dropout mask along the time axis
        self.out_proj = ActivationDropoutAndLinear(
        self.out_proj = ActivationDropoutAndLinear_lora(
            feedforward_dim,
            embed_dim,
            activation="SwooshL",
            dropout_p=dropout,
            dropout_shared_dim=0,
            bias=True,
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            initial_scale=0.1,
        )
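With in_proj and out_proj now LoRA-capable, a fine-tuning script would typically freeze the pretrained weights and update only the low-rank adapters. The helper below is a sketch under the assumption (borrowed from the reference LoRA implementation, not confirmed by this diff) that ScaledLinear_lora registers its adapter tensors under parameter names containing "lora_"; mark_only_lora_as_trainable is a hypothetical helper name, and the actual attribute names should be checked against ScaledLinear_lora in scaling.py.

    import torch


    def mark_only_lora_as_trainable(model: torch.nn.Module) -> None:
        # Freeze everything, then leave gradients enabled only for parameters
        # whose names look like LoRA factors (e.g. "lora_A", "lora_B").  The
        # naming is an assumption; verify it in scaling.py.
        for name, param in model.named_parameters():
            param.requires_grad = "lora_" in name


    # Hypothetical usage after building the model:
    # mark_only_lora_as_trainable(model)
    # n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)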