Convert swish nonlinearities to ReLU

Daniel Povey 2022-03-05 16:28:24 +08:00
parent 0cd14ae739
commit 5f2c0a09b7
3 changed files with 82 additions and 9 deletions
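
What changes: the old ExpScaleSwish activation computed (x * sigmoid(x)) * exp(scale * speed) (visible in the commented-out line the diff leaves behind), while the new ExpScaleRelu computes relu(x * exp(scale * speed)) (the _exp_scale_relu helper added below). A minimal sketch of the two formulas as plain functions, just to make the swap concrete (this block is illustrative, not part of the commit):

    import torch
    from torch import Tensor

    def exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor:
        # Old nonlinearity: Swish (x * sigmoid(x)), times a learned exponential scale.
        return (x * torch.sigmoid(x)) * (scale * speed).exp()

    def exp_scale_relu(x: Tensor, scale: Tensor, speed: float) -> Tensor:
        # New nonlinearity: learned exponential scale applied first, then ReLU.
        return (x * (scale * speed).exp()).relu()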

View File

@@ -49,15 +49,13 @@ class Conv2dSubsampling(nn.Module):
             ),
             DerivBalancer(channel_dim=1, threshold=0.05,
                           max_factor=0.025),
-            nn.ReLU(),
-            ExpScale(odim, 1, 1, speed=20.0),
+            ExpScaleRelu(odim, 1, 1, speed=20.0),
             nn.Conv2d(
                 in_channels=odim, out_channels=odim, kernel_size=3, stride=2
             ),
             DerivBalancer(channel_dim=1, threshold=0.05,
                           max_factor=0.025),
-            nn.ReLU(),
-            ExpScale(odim, 1, 1, speed=20.0),
+            ExpScaleRelu(odim, 1, 1, speed=20.0),
         )
         self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
         self.out_norm = nn.LayerNorm(odim, elementwise_affine=False)
@@ -253,6 +251,60 @@ class ExpScaleSwish(torch.nn.Module):
         # return x * (self.scale * self.speed).exp()


+def _exp_scale_relu(x: Tensor, scale: Tensor, speed: float) -> Tensor:
+    return (x * (scale * speed).exp()).relu()
+
+
+class ExpScaleSwishFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor:
+        ctx.save_for_backward(x.detach(), scale.detach())
+        ctx.speed = speed
+        return _exp_scale_swish(x, scale, speed)
+
+    @staticmethod
+    def backward(ctx, y_grad: Tensor) -> Tensor:
+        x, scale = ctx.saved_tensors
+        x.requires_grad = True
+        scale.requires_grad = True
+        with torch.enable_grad():
+            y = _exp_scale_swish(x, scale, ctx.speed)
+            y.backward(gradient=y_grad)
+        return x.grad, scale.grad, None
+
+
+class ExpScaleReluFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor:
+        ctx.save_for_backward(x.detach(), scale.detach())
+        ctx.speed = speed
+        return _exp_scale_relu(x, scale, speed)
+
+    @staticmethod
+    def backward(ctx, y_grad: Tensor) -> Tensor:
+        x, scale = ctx.saved_tensors
+        x.requires_grad = True
+        scale.requires_grad = True
+        with torch.enable_grad():
+            y = _exp_scale_relu(x, scale, ctx.speed)
+            y.backward(gradient=y_grad)
+        return x.grad, scale.grad, None
+
+
+class ExpScaleRelu(torch.nn.Module):
+    # Combines ExpScale and ReLU; the learned scale is applied before the ReLU.
+    # Caution: need to specify name for speed, e.g. ExpScaleRelu(50, speed=4.0)
+    def __init__(self, *shape, speed: float = 1.0):
+        super(ExpScaleRelu, self).__init__()
+        self.scale = nn.Parameter(torch.zeros(*shape))
+        self.speed = speed
+
+    def forward(self, x: Tensor) -> Tensor:
+        return ExpScaleReluFunction.apply(x, self.scale, self.speed)
+        # return (x * (self.scale * self.speed).exp()).relu()
+
+
 class DerivBalancerFunction(torch.autograd.Function):
@@ -335,6 +387,23 @@ def _test_exp_scale_swish():
     y2.sum().backward()
     assert torch.allclose(x1.grad, x2.grad)


+def _test_exp_scale_relu():
+    x1 = torch.randn(50, 60).detach()
+    x2 = x1.detach()
+
+    m1 = ExpScaleRelu(50, 1, speed=4.0)
+    m2 = torch.nn.Sequential(nn.ReLU(), ExpScale(50, 1, speed=4.0))
+    x1.requires_grad = True
+    x2.requires_grad = True
+
+    y1 = m1(x1)
+    y2 = m2(x2)
+    assert torch.allclose(y1, y2)
+    y1.sum().backward()
+    y2.sum().backward()
+    assert torch.allclose(x1.grad, x2.grad)
+
+
 def _test_deriv_balancer():
@@ -360,3 +429,4 @@ def _test_deriv_balancer():
 if __name__ == '__main__':
     _test_deriv_balancer()
     _test_exp_scale_swish()
+    _test_exp_scale_relu()
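
A note on why _test_exp_scale_relu can compare ExpScaleRelu against nn.ReLU followed by ExpScale, even though the fused module applies the scale before the ReLU: the factor exp(scale * speed) is strictly positive, and ReLU commutes with multiplication by any positive factor, i.e. relu(x * c) == relu(x) * c for c > 0. A quick numeric illustration of that identity (not part of the commit):

    import torch

    x = torch.randn(100)
    c = torch.randn(1).exp()          # exp() makes the factor strictly positive
    assert torch.allclose((x * c).relu(), x.relu() * c)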

View File

@@ -19,7 +19,7 @@ import copy
 import math
 import warnings
 from typing import Optional, Tuple, Sequence
-from subsampling import PeLU, ExpScale, ExpScaleSwish, DerivBalancer
+from subsampling import PeLU, ExpScale, ExpScaleSwish, ExpScaleRelu, DerivBalancer
 import torch
 from torch import Tensor, nn
@@ -158,7 +158,7 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(d_model, dim_feedforward),
             DerivBalancer(channel_dim=-1, threshold=0.05,
                           max_factor=0.025),
-            ExpScaleSwish(dim_feedforward, speed=20.0),
+            ExpScaleRelu(dim_feedforward, speed=20.0),
             nn.Dropout(dropout),
             nn.Linear(dim_feedforward, d_model),
         )
@@ -167,7 +167,7 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(d_model, dim_feedforward),
             DerivBalancer(channel_dim=-1, threshold=0.05,
                           max_factor=0.025),
-            ExpScaleSwish(dim_feedforward, speed=20.0),
+            ExpScaleRelu(dim_feedforward, speed=20.0),
             nn.Dropout(dropout),
             nn.Linear(dim_feedforward, d_model),
         )
@@ -877,8 +877,10 @@ class ConvolutionModule(nn.Module):
             groups=channels,
             bias=bias,
         )
+        self.balancer = DerivBalancer(channel_dim=1, threshold=0.05,
+                                      max_factor=0.025)
         # shape: (channels, 1), broadcasts with (batch, channel, time).
-        self.activation = ExpScaleSwish(channels, 1, speed=20.0)
+        self.activation = ExpScaleRelu(channels, 1, speed=20.0)

         self.pointwise_conv2 = nn.Conv1d(
             channels,
@@ -910,6 +912,7 @@
         x = self.depthwise_conv(x)
         # x is (batch, channels, time)
+        x = self.balancer(x)
         x = self.activation(x)

         x = self.pointwise_conv2(x)  # (batch, channel, time)
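
The "(channels, 1)" comment above refers to broadcasting: ExpScaleRelu(channels, 1, speed=20.0) creates a scale parameter of shape (channels, 1), which broadcasts against the (batch, channel, time) activations in ConvolutionModule. A small shape check under those assumptions (sizes here are illustrative):

    import torch

    batch, channels, time = 2, 8, 50
    x = torch.randn(batch, channels, time)
    scale = torch.zeros(channels, 1)           # same shape as ExpScaleRelu(channels, 1).scale
    y = (x * (scale * 20.0).exp()).relu()      # the _exp_scale_relu formula
    assert y.shape == (batch, channels, time)  # per-channel scale broadcasts over batch and time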

View File

@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu2",
+        default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu2relu",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved