marcoyang 2024-03-11 12:18:03 +08:00
parent 3492d9415c
commit 6806810666


@@ -1755,8 +1755,6 @@ class ActivationDropoutAndLinear_lora(torch.nn.Module):
         initial_scale: float = 1.0,
     ):
         super().__init__()
-        # create a temporary module of nn.Linear that we'll steal the
-        # weights and bias from
         self.l = ScaledLinear_lora(
             in_features=in_channels,
             out_features=out_channels,
@@ -1767,17 +1765,16 @@ class ActivationDropoutAndLinear_lora(torch.nn.Module):
             bias=bias,
         )

-        self.activation = activation
+        if activation == "SwooshL":
+            self.activation = SwooshL()
+        elif activation == "SwooshR":
+            self.activation = SwooshR()
+        else:
+            assert False, activation
         self.dropout = Dropout3(dropout_p, dropout_shared_dim)

     def forward(self, x: Tensor):
-        if self.activation == "SwooshL":
-            x = SwooshLForward(x)
-        elif self.activation == "SwooshR":
-            x = SwooshRForward(x)
-        else:
-            assert False, self.activation
-        return self.dropout(self.l(x))
+        return self.l(self.dropout(self.activation(x)))


 def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
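
In effect, the change moves activation handling from a string dispatch inside forward() to activation modules built once in __init__, and reorders the ops so dropout is applied before the LoRA linear rather than after it. A minimal sketch of the before/after behavior, using plain torch stand-ins for the repo's SwooshL, Dropout3 and ScaledLinear_lora (the stand-in layers and sizes below are illustrative, not taken from the commit):

import torch
import torch.nn as nn

activation = nn.SiLU()        # stand-in for SwooshL()
dropout = nn.Dropout(p=0.1)   # stand-in for Dropout3(dropout_p, dropout_shared_dim)
linear = nn.Linear(384, 256)  # stand-in for ScaledLinear_lora(...)

x = torch.randn(10, 4, 384)   # (seq_len, batch, in_channels)

# Before this commit: activation, then the LoRA linear, then dropout.
y_old = dropout(linear(activation(x)))

# After this commit: activation, then dropout, then the LoRA linear,
# i.e. the order the module name "ActivationDropoutAndLinear" implies.
y_new = linear(dropout(activation(x)))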