Remove some dead code.

Daniel Povey 2022-03-16 18:06:01 +08:00
parent c82db4184a
commit dfc75752c4
2 changed files with 1 addition and 161 deletions


@@ -174,130 +174,6 @@ class VggSubsampling(nn.Module):
        return x
class PeLUFunction(torch.autograd.Function):
    """
    Computes PeLU function (PeLUFunction.apply(x, cutoff, alpha)).
    The function is:
        x.relu() + alpha * (cutoff - x).relu()
    E.g. consider cutoff = -1, alpha = 0.01. This will tend to prevent die-off
    of neurons.
    """
    @staticmethod
    def forward(ctx, x: Tensor, cutoff: float, alpha: float) -> Tensor:
        mask1 = (x >= 0)  # >=, so there is deriv if x == 0.
        p = cutoff - x
        mask2 = (p >= 0)
        ctx.save_for_backward(mask1, mask2)
        ctx.alpha = alpha
        return x.relu() + alpha * p.relu()

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None]:
        mask1, mask2 = ctx.saved_tensors
        return mask1 * ans_grad - (ctx.alpha * mask2) * ans_grad, None, None


class PeLU(torch.nn.Module):
    def __init__(self, cutoff: float = -1.0, alpha: float = 0.01) -> None:
        super(PeLU, self).__init__()
        self.cutoff = cutoff
        self.alpha = alpha

    def forward(self, x: Tensor) -> Tensor:
        return PeLUFunction.apply(x, self.cutoff, self.alpha)
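For reference, the removed PeLU is elementwise and needs no custom autograd to reproduce numerically; a minimal sketch using only the formula from the docstring above (the helper name pelu_ref is illustrative, not from the repository):

import torch

def pelu_ref(x: torch.Tensor, cutoff: float = -1.0, alpha: float = 0.01) -> torch.Tensor:
    # Same formula as PeLUFunction.forward above.
    return x.relu() + alpha * (cutoff - x).relu()

x = torch.tensor([-2.0, -0.5, 3.0])
print(pelu_ref(x))  # tensor([0.0100, 0.0000, 3.0000])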
class ExpScale(torch.nn.Module):
    def __init__(self, *shape, speed: float = 1.0, initial_scale: float = 1.0):
        super(ExpScale, self).__init__()
        scale = torch.tensor(initial_scale)
        scale = scale.log() / speed
        self.scale = nn.Parameter(scale.detach())
        self.speed = speed

    def forward(self, x: Tensor) -> Tensor:
        return x * (self.scale * self.speed).exp()
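For context, ExpScale multiplies its input by exp(scale * speed), and the parameter is initialised to log(initial_scale) / speed, so the initial multiplier equals initial_scale regardless of speed. A quick illustrative check, assuming the ExpScale class above is in scope:

import torch

m = ExpScale(speed=4.0, initial_scale=2.0)
x = torch.ones(3)
# exp(log(2.0) / 4.0 * 4.0) == 2.0, so the initial output is just 2 * x.
assert torch.allclose(m(x), 2.0 * x)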
def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor:
    # double-swish, implemented/approximated as offset-swish
    x = (x * torch.sigmoid(x - 1.0))
    x = x * (scale * speed).exp()
    return x


class SwishExpScaleFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor:
        ctx.save_for_backward(x.detach(), scale.detach())
        ctx.speed = speed
        return _exp_scale_swish(x, scale, speed)

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        x, scale = ctx.saved_tensors
        x.requires_grad = True
        scale.requires_grad = True
        with torch.enable_grad():
            y = _exp_scale_swish(x, scale, ctx.speed)
            y.backward(gradient=y_grad)
        return x.grad, scale.grad, None
class SwishExpScale(torch.nn.Module):
    # combines ExpScale and a Swish (actually the ExpScale is after the Swish).
    # caution: need to specify name for speed, e.g. SwishExpScale(50, speed=4.0)
    #
    def __init__(self, *shape, speed: float = 1.0):
        super(SwishExpScale, self).__init__()
        initial_log_scale = torch.zeros(()).detach()
        self.scale = nn.Parameter(initial_log_scale)
        self.speed = speed

    def forward(self, x: Tensor) -> Tensor:
        return SwishExpScaleFunction.apply(x, self.scale, self.speed)
        # x = (x * torch.sigmoid(x))
        # x = (x * torch.sigmoid(x))
        # x = x * (self.scale * self.speed).exp()
        # return x
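The removed SwishExpScaleFunction (and ExpScaleReluFunction below) use a recompute-in-backward pattern: forward() saves only detached inputs, and backward() re-runs the cheap elementwise forward under torch.enable_grad() so autograd produces the gradients, trading a little extra compute for not keeping intermediate activations. A self-contained sketch of the same pattern with illustrative names (not the repository's code), checked against plain autograd in the style of the removed tests below:

import torch
from torch import Tensor

def _scaled_swish(x: Tensor, scale: Tensor) -> Tensor:
    return x * torch.sigmoid(x - 1.0) * scale.exp()

class RecomputeScaledSwish(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, scale: Tensor) -> Tensor:
        # Save detached copies only; no intermediates are kept for backward.
        ctx.save_for_backward(x.detach(), scale.detach())
        return _scaled_swish(x, scale)

    @staticmethod
    def backward(ctx, y_grad: Tensor):
        x, scale = ctx.saved_tensors
        x.requires_grad = True
        scale.requires_grad = True
        with torch.enable_grad():
            # Recompute the forward pass and differentiate it on the fly.
            _scaled_swish(x, scale).backward(gradient=y_grad)
        return x.grad, scale.grad

x1 = torch.randn(4, 3).requires_grad_(True)
x2 = x1.detach().requires_grad_(True)
s1 = torch.zeros((), requires_grad=True)
s2 = torch.zeros((), requires_grad=True)

y1 = RecomputeScaledSwish.apply(x1, s1)
y2 = _scaled_swish(x2, s2)  # ordinary autograd keeps intermediates
assert torch.allclose(y1, y2)

y1.sum().backward()
y2.sum().backward()
assert torch.allclose(x1.grad, x2.grad)
assert torch.allclose(s1.grad, s2.grad)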
def _exp_scale_relu(x: Tensor, scale: Tensor, speed: float) -> Tensor:
    return (x * (scale * speed).exp()).relu()


class ExpScaleReluFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor:
        ctx.save_for_backward(x.detach(), scale.detach())
        ctx.speed = speed
        return _exp_scale_relu(x, scale, speed)

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        x, scale = ctx.saved_tensors
        x.requires_grad = True
        scale.requires_grad = True
        with torch.enable_grad():
            y = _exp_scale_relu(x, scale, ctx.speed)
            y.backward(gradient=y_grad)
        return x.grad, scale.grad, None
class ExpScaleRelu(torch.nn.Module):
    # combines ExpScale and Relu.
    # caution: need to specify name for speed, e.g. ExpScaleRelu(50, speed=4.0)
    def __init__(self, *shape, speed: float = 1.0):
        super(ExpScaleRelu, self).__init__()
        self.scale = nn.Parameter(torch.zeros(*shape))
        self.speed = speed

    def forward(self, x: Tensor) -> Tensor:
        return ExpScaleReluFunction.apply(x, self.scale, self.speed)
        # return (x * torch.sigmoid(x)) * (self.scale * self.speed).exp()
        # return x * (self.scale * self.speed).exp()
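One property worth noting, which the removed _test_exp_scale_relu further below relies on: the factor exp(scale * speed) is strictly positive, and ReLU commutes with multiplication by a positive constant, so applying the scale before or after the ReLU gives the same result. A tiny illustrative check (not from the original file):

import torch

x = torch.randn(8)
c = torch.tensor(0.7).exp()  # any strictly positive factor
assert torch.allclose((x * c).relu(), x.relu() * c)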
@@ -639,40 +515,6 @@ class DoubleSwish(torch.nn.Module):
""" """
return DoubleSwishFunction.apply(x) return DoubleSwishFunction.apply(x)
def _test_exp_scale_swish():
    x1 = torch.randn(50, 60).detach()
    x2 = x1.detach()

    m1 = SwishExpScale(50, 1, speed=4.0)
    m2 = torch.nn.Sequential(DoubleSwish(), ExpScale(50, 1, speed=4.0))
    x1.requires_grad = True
    x2.requires_grad = True

    y1 = m1(x1)
    y2 = m2(x2)
    assert torch.allclose(y1, y2, atol=1e-05)
    y1.sum().backward()
    y2.sum().backward()
    assert torch.allclose(x1.grad, x2.grad, atol=1e-05)


def _test_exp_scale_relu():
    x1 = torch.randn(50, 60).detach()
    x2 = x1.detach()

    m1 = ExpScaleRelu(50, 1, speed=4.0)
    m2 = torch.nn.Sequential(nn.ReLU(), ExpScale(50, 1, speed=4.0))
    x1.requires_grad = True
    x2.requires_grad = True

    y1 = m1(x1)
    y2 = m2(x2)
    assert torch.allclose(y1, y2)
    y1.sum().backward()
    y2.sum().backward()
    assert torch.allclose(x1.grad, x2.grad)
def _test_deriv_balancer_sign():
@@ -737,6 +579,4 @@ def _test_basic_norm():
if __name__ == '__main__':
    _test_deriv_balancer_sign()
    _test_deriv_balancer_magnitude()
    _test_exp_scale_swish()
    _test_exp_scale_relu()
    _test_basic_norm()


@@ -19,7 +19,7 @@ import copy
import math
import warnings
from typing import Optional, Tuple, Sequence
from subsampling import PeLU, ExpScale, DoubleSwish, SwishExpScale, ExpScaleRelu, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d
from subsampling import DoubleSwish, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d
import torch
from torch import Tensor, nn