mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
bug fix
This commit is contained in:
parent
3492d9415c
commit
6806810666
@ -1755,8 +1755,6 @@ class ActivationDropoutAndLinear_lora(torch.nn.Module):
|
|||||||
initial_scale: float = 1.0,
|
initial_scale: float = 1.0,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# create a temporary module of nn.Linear that we'll steal the
|
|
||||||
# weights and bias from
|
|
||||||
self.l = ScaledLinear_lora(
|
self.l = ScaledLinear_lora(
|
||||||
in_features=in_channels,
|
in_features=in_channels,
|
||||||
out_features=out_channels,
|
out_features=out_channels,
|
||||||
@ -1767,17 +1765,16 @@ class ActivationDropoutAndLinear_lora(torch.nn.Module):
|
|||||||
bias=bias,
|
bias=bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.activation = activation
|
if activation == "SwooshL":
|
||||||
|
self.activation = SwooshL()
|
||||||
|
elif activation == "SwooshR":
|
||||||
|
self.activation = SwooshR()
|
||||||
|
else:
|
||||||
|
assert False, activation
|
||||||
self.dropout = Dropout3(dropout_p, dropout_shared_dim)
|
self.dropout = Dropout3(dropout_p, dropout_shared_dim)
|
||||||
|
|
||||||
def forward(self, x: Tensor):
|
def forward(self, x: Tensor):
|
||||||
if self.activation == "SwooshL":
|
return self.l(self.dropout(self.activation(x)))
|
||||||
x = SwooshLForward(x)
|
|
||||||
elif self.activation == "SwooshR":
|
|
||||||
x = SwooshRForward(x)
|
|
||||||
else:
|
|
||||||
assert False, self.activation
|
|
||||||
return self.dropout(self.l(x))
|
|
||||||
|
|
||||||
|
|
||||||
def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
|
def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user