mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)
bug fix

commit 6806810666
parent 3492d9415c
@@ -1755,8 +1755,6 @@ class ActivationDropoutAndLinear_lora(torch.nn.Module):
         initial_scale: float = 1.0,
     ):
         super().__init__()
         # create a temporary module of nn.Linear that we'll steal the
         # weights and bias from
         self.l = ScaledLinear_lora(
             in_features=in_channels,
             out_features=out_channels,
@@ -1767,17 +1765,16 @@ class ActivationDropoutAndLinear_lora(torch.nn.Module):
             bias=bias,
         )

-        if activation == "SwooshL":
-            self.activation = SwooshL()
-        elif activation == "SwooshR":
-            self.activation = SwooshR()
-        else:
-            assert False, activation
+        self.activation = activation
         self.dropout = Dropout3(dropout_p, dropout_shared_dim)

     def forward(self, x: Tensor):
-        return self.l(self.dropout(self.activation(x)))
+        if self.activation == "SwooshL":
+            x = SwooshLForward(x)
+        elif self.activation == "SwooshR":
+            x = SwooshRForward(x)
+        else:
+            assert False, self.activation
+        return self.dropout(self.l(x))


 def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
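As far as the hunks suggest, the fix stops building activation submodules (SwooshL() / SwooshR()) in __init__; instead it stores the activation name and has forward() dispatch on it via the functional SwooshLForward / SwooshRForward, with dropout applied to the linear layer's output. Below is a minimal, self-contained sketch of that pattern in plain PyTorch, not icefall's actual code: the class name, the nn.Linear / nn.Dropout stand-ins for ScaledLinear_lora / Dropout3, and the constructor parameters are illustrative assumptions; the Swoosh formulas follow the Zipformer paper.

    # Minimal sketch (plain PyTorch, no icefall imports) of the fixed pattern:
    # store the activation *name* in __init__, dispatch functionally in forward().
    import torch
    from torch import Tensor, nn

    def swoosh_l(x: Tensor) -> Tensor:
        # SwooshL(x) = log(1 + exp(x - 4)) - 0.08x - 0.035  (Zipformer paper)
        return torch.nn.functional.softplus(x - 4.0) - 0.08 * x - 0.035

    def swoosh_r(x: Tensor) -> Tensor:
        # SwooshR(x) = log(1 + exp(x - 1)) - 0.08x - 0.313261687  (Zipformer paper)
        return torch.nn.functional.softplus(x - 1.0) - 0.08 * x - 0.313261687

    class ActivationDropoutAndLinearSketch(nn.Module):
        def __init__(self, in_channels: int, out_channels: int,
                     activation: str = "SwooshL", dropout_p: float = 0.0):
            super().__init__()
            # Keep the activation name, not a module, matching the fixed __init__.
            self.activation = activation
            self.l = nn.Linear(in_channels, out_channels)  # stand-in for ScaledLinear_lora
            self.dropout = nn.Dropout(dropout_p)           # stand-in for Dropout3

        def forward(self, x: Tensor) -> Tensor:
            # Dispatch on the stored name, as the fixed forward() does.
            if self.activation == "SwooshL":
                x = swoosh_l(x)
            elif self.activation == "SwooshR":
                x = swoosh_r(x)
            else:
                assert False, self.activation
            return self.dropout(self.l(x))

    x = torch.randn(2, 5, 8)
    m = ActivationDropoutAndLinearSketch(8, 16, activation="SwooshR", dropout_p=0.1)
    print(m(x).shape)  # torch.Size([2, 5, 16])

Keeping the activation as a string rather than a submodule matches the fixed __init__ and puts all of the dispatch logic in one place in forward().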