Remove dead code

This commit is contained in:
Daniel Povey 2022-03-16 12:49:00 +08:00
parent a783b96467
commit 00be56c7a0

View File

@ -876,8 +876,7 @@ class ConvolutionModule(nn.Module):
self.deriv_balancer2 = DerivBalancer(channel_dim=1, self.deriv_balancer2 = DerivBalancer(channel_dim=1,
min_positive=0.05, max_positive=1.0) min_positive=0.05, max_positive=1.0)
# Shape: (channels, 1), broadcasts with (batch, channel, time). self.activation = DoubleSwish()
self.activation = SwishOffset()
self.pointwise_conv2 = ScaledConv1d( self.pointwise_conv2 = ScaledConv1d(
channels, channels,
@ -918,24 +917,6 @@ class ConvolutionModule(nn.Module):
return x.permute(2, 0, 1) return x.permute(2, 0, 1)
class Swish(torch.nn.Module):
"""Construct an Swish object."""
def forward(self, x: Tensor) -> Tensor:
"""Return Swich activation function."""
return x * torch.sigmoid(x)
class SwishOffset(torch.nn.Module):
"""Construct an SwishOffset object."""
def __init__(self, offset: float = -1.0) -> None:
super(SwishOffset, self).__init__()
self.offset = offset
def forward(self, x: Tensor) -> Tensor:
"""Return Swich activation function."""
return x * torch.sigmoid(x + self.offset)
class Identity(torch.nn.Module): class Identity(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
return x return x