marcoyang 2024-03-11 12:18:03 +08:00
parent 3492d9415c
commit 6806810666


@@ -1755,8 +1755,6 @@ class ActivationDropoutAndLinear_lora(torch.nn.Module):
         initial_scale: float = 1.0,
     ):
         super().__init__()
-        # create a temporary module of nn.Linear that we'll steal the
-        # weights and bias from
         self.l = ScaledLinear_lora(
             in_features=in_channels,
             out_features=out_channels,
@@ -1767,17 +1765,16 @@ class ActivationDropoutAndLinear_lora(torch.nn.Module):
             bias=bias,
         )
-        self.activation = activation
+        if activation == "SwooshL":
+            self.activation = SwooshL()
+        elif activation == "SwooshR":
+            self.activation = SwooshR()
+        else:
+            assert False, activation
         self.dropout = Dropout3(dropout_p, dropout_shared_dim)

     def forward(self, x: Tensor):
-        if self.activation == "SwooshL":
-            x = SwooshLForward(x)
-        elif self.activation == "SwooshR":
-            x = SwooshRForward(x)
-        else:
-            assert False, self.activation
-        return self.dropout(self.l(x))
+        return self.l(self.dropout(self.activation(x)))


 def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
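
For context, the change constructs the activation as a module (SwooshL() or SwooshR()) in __init__ and applies activation, then dropout, then the LoRA linear in forward(), rather than branching on the activation string inside forward(). The sketch below mirrors that structure using stock PyTorch stand-ins (nn.SiLU for SwooshL/SwooshR, nn.Dropout for Dropout3, nn.Linear for ScaledLinear_lora); the class and argument names are illustrative and not taken from this repository.

# Minimal sketch of the post-change structure, with stand-in modules.
import torch
from torch import Tensor, nn


class ActivationDropoutAndLinearSketch(nn.Module):
    def __init__(self, in_channels: int, out_channels: int,
                 activation: str = "SwooshL", dropout_p: float = 0.1):
        super().__init__()
        # Pick the activation module once, at construction time.
        if activation == "SwooshL":
            self.activation = nn.SiLU()  # stand-in for SwooshL()
        elif activation == "SwooshR":
            self.activation = nn.SiLU()  # stand-in for SwooshR()
        else:
            assert False, activation
        self.dropout = nn.Dropout(dropout_p)  # stand-in for Dropout3
        self.l = nn.Linear(in_channels, out_channels)  # stand-in for ScaledLinear_lora

    def forward(self, x: Tensor) -> Tensor:
        # activation -> dropout -> linear, matching the new forward().
        return self.l(self.dropout(self.activation(x)))


x = torch.randn(4, 16)
y = ActivationDropoutAndLinearSketch(16, 32)(x)
print(y.shape)  # torch.Size([4, 32])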