Convert swish nonlinearities to ReLU

Daniel Povey 2022-03-05 16:28:24 +08:00
parent 0cd14ae739
commit 5f2c0a09b7
3 changed files with 82 additions and 9 deletions
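
In short: the scaled-Swish nonlinearity (ExpScaleSwish) is replaced by a scaled ReLU (ExpScaleRelu) in the Conv2dSubsampling frontend, in both feedforward modules of the ConformerEncoderLayer, and in the ConvolutionModule, which also gains a DerivBalancer before its activation. A minimal sketch of the two activations, written as standalone functions for illustration only (the Swish form follows the commented-out return in ExpScaleSwish; in the actual modules, `scale` is a learned parameter):

```python
import torch
from torch import Tensor


def exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor:
    # Old: Swish (x * sigmoid(x)) times a learned positive scale,
    # parameterized as exp(scale * speed).
    return (x * torch.sigmoid(x)) * (scale * speed).exp()


def exp_scale_relu(x: Tensor, scale: Tensor, speed: float) -> Tensor:
    # New: the scale is applied before the ReLU; since
    # exp(scale * speed) > 0, this equals exp(scale * speed) * relu(x).
    return (x * (scale * speed).exp()).relu()
```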

subsampling.py

@@ -49,15 +49,13 @@ class Conv2dSubsampling(nn.Module):
             ),
             DerivBalancer(channel_dim=1, threshold=0.05,
                           max_factor=0.025),
-            nn.ReLU(),
-            ExpScale(odim, 1, 1, speed=20.0),
+            ExpScaleRelu(odim, 1, 1, speed=20.0),
             nn.Conv2d(
                 in_channels=odim, out_channels=odim, kernel_size=3, stride=2
             ),
             DerivBalancer(channel_dim=1, threshold=0.05,
                           max_factor=0.025),
-            nn.ReLU(),
-            ExpScale(odim, 1, 1, speed=20.0),
+            ExpScaleRelu(odim, 1, 1, speed=20.0),
         )
         self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
         self.out_norm = nn.LayerNorm(odim, elementwise_affine=False)
@@ -253,6 +251,60 @@ class ExpScaleSwish(torch.nn.Module):
         # return x * (self.scale * self.speed).exp()
 
 
+def _exp_scale_relu(x: Tensor, scale: Tensor, speed: float) -> Tensor:
+    return (x * (scale * speed).exp()).relu()
+
+
+class ExpScaleSwishFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor:
+        ctx.save_for_backward(x.detach(), scale.detach())
+        ctx.speed = speed
+        return _exp_scale_swish(x, scale, speed)
+
+    @staticmethod
+    def backward(ctx, y_grad: Tensor) -> Tensor:
+        x, scale = ctx.saved_tensors
+        x.requires_grad = True
+        scale.requires_grad = True
+        with torch.enable_grad():
+            y = _exp_scale_swish(x, scale, ctx.speed)
+            y.backward(gradient=y_grad)
+        return x.grad, scale.grad, None
+
+
+class ExpScaleReluFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor:
+        ctx.save_for_backward(x.detach(), scale.detach())
+        ctx.speed = speed
+        return _exp_scale_relu(x, scale, speed)
+
+    @staticmethod
+    def backward(ctx, y_grad: Tensor) -> Tensor:
+        x, scale = ctx.saved_tensors
+        x.requires_grad = True
+        scale.requires_grad = True
+        with torch.enable_grad():
+            y = _exp_scale_relu(x, scale, ctx.speed)
+            y.backward(gradient=y_grad)
+        return x.grad, scale.grad, None
+
+
+class ExpScaleRelu(torch.nn.Module):
+    # Combines ExpScale and ReLU; since exp(scale * speed) > 0, the
+    # scaling commutes with the ReLU.
+    # Caution: need to specify name for speed, e.g. ExpScaleRelu(50, speed=4.0)
+    def __init__(self, *shape, speed: float = 1.0):
+        super(ExpScaleRelu, self).__init__()
+        self.scale = nn.Parameter(torch.zeros(*shape))
+        self.speed = speed
+
+    def forward(self, x: Tensor) -> Tensor:
+        return ExpScaleReluFunction.apply(x, self.scale, self.speed)
+
+
 class DerivBalancerFunction(torch.autograd.Function):
@@ -335,6 +387,23 @@ def _test_exp_scale_swish():
     y2.sum().backward()
     assert torch.allclose(x1.grad, x2.grad)
 
 
+def _test_exp_scale_relu():
+    x1 = torch.randn(50, 60).detach()
+    x2 = x1.detach()
+
+    m1 = ExpScaleRelu(50, 1, speed=4.0)
+    m2 = torch.nn.Sequential(nn.ReLU(), ExpScale(50, 1, speed=4.0))
+    x1.requires_grad = True
+    x2.requires_grad = True
+
+    y1 = m1(x1)
+    y2 = m2(x2)
+    assert torch.allclose(y1, y2)
+    y1.sum().backward()
+    y2.sum().backward()
+    assert torch.allclose(x1.grad, x2.grad)
+
+
 def _test_deriv_balancer():
@@ -360,3 +429,4 @@ def _test_deriv_balancer():
 if __name__ == '__main__':
     _test_deriv_balancer()
     _test_exp_scale_swish()
+    _test_exp_scale_relu()
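
ExpScaleReluFunction saves only detached inputs and recomputes the forward pass under torch.enable_grad() inside backward(), trading a little extra compute for activation memory. A quick sanity check, not part of the commit, that the custom backward matches plain autograd on the same expression (assumes subsampling.py is importable, per the import in conformer.py below):

```python
import torch
from subsampling import ExpScaleReluFunction  # module path assumed

speed = 20.0
x = torch.randn(4, 8, requires_grad=True)
scale = torch.zeros(8, requires_grad=True)

# Custom Function: recomputes the forward inside backward().
y1 = ExpScaleReluFunction.apply(x, scale, speed)
y1.sum().backward()
gx, gs = x.grad.clone(), scale.grad.clone()

# Same math under plain autograd.
x.grad, scale.grad = None, None
y2 = (x * (scale * speed).exp()).relu()
y2.sum().backward()

assert torch.allclose(y1, y2)
assert torch.allclose(gx, x.grad) and torch.allclose(gs, scale.grad)
```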

conformer.py

@@ -19,7 +19,7 @@ import copy
 import math
 import warnings
 from typing import Optional, Tuple, Sequence
-from subsampling import PeLU, ExpScale, ExpScaleSwish, DerivBalancer
+from subsampling import PeLU, ExpScale, ExpScaleSwish, ExpScaleRelu, DerivBalancer
 
 import torch
 from torch import Tensor, nn
@@ -158,7 +158,7 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(d_model, dim_feedforward),
             DerivBalancer(channel_dim=-1, threshold=0.05,
                           max_factor=0.025),
-            ExpScaleSwish(dim_feedforward, speed=20.0),
+            ExpScaleRelu(dim_feedforward, speed=20.0),
             nn.Dropout(dropout),
             nn.Linear(dim_feedforward, d_model),
         )
@@ -167,7 +167,7 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(d_model, dim_feedforward),
             DerivBalancer(channel_dim=-1, threshold=0.05,
                           max_factor=0.025),
-            ExpScaleSwish(dim_feedforward, speed=20.0),
+            ExpScaleRelu(dim_feedforward, speed=20.0),
             nn.Dropout(dropout),
             nn.Linear(dim_feedforward, d_model),
         )
@@ -877,8 +877,10 @@ class ConvolutionModule(nn.Module):
             groups=channels,
             bias=bias,
         )
+        self.balancer = DerivBalancer(channel_dim=1, threshold=0.05,
+                                      max_factor=0.025)
         # shape: (channels, 1), broadcasts with (batch, channel, time).
-        self.activation = ExpScaleSwish(channels, 1, speed=20.0)
+        self.activation = ExpScaleRelu(channels, 1, speed=20.0)
 
         self.pointwise_conv2 = nn.Conv1d(
             channels,
@@ -910,6 +912,7 @@ class ConvolutionModule(nn.Module):
         x = self.depthwise_conv(x)
         # x is (batch, channels, time)
+        x = self.balancer(x)
         x = self.activation(x)
 
         x = self.pointwise_conv2(x)  # (batch, channel, time)
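
With these two hunks, the depthwise-conv output now passes through the new DerivBalancer before the activation, i.e. depthwise_conv → balancer → activation → pointwise_conv2, and the activation's scale parameter has shape (channels, 1) so it broadcasts over the time axis of a (batch, channel, time) tensor. A small illustrative shape sketch (sizes assumed; zeros give exp(0) = 1, the initialization used by ExpScaleRelu):

```python
import torch

batch, channels, time = 2, 16, 50
x = torch.randn(batch, channels, time)

# ExpScaleRelu(channels, 1, speed=20.0) holds a (channels, 1) scale;
# it is zero-initialized, so exp(0 * speed) = 1 at the start of training.
scale = torch.zeros(channels, 1)
speed = 20.0

y = (x * (scale * speed).exp()).relu()  # broadcasts over the time axis
assert y.shape == (batch, channels, time)
```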

train.py

@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu2",
+        default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu2relu",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved