Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 10:02:22 +00:00)

Remove some dead code.

parent c82db4184a
commit dfc75752c4
@@ -174,130 +174,6 @@ class VggSubsampling(nn.Module):
         return x


-class PeLUFunction(torch.autograd.Function):
-    """
-    Computes PeLU function (PeLUFunction.apply(x, cutoff, alpha)).
-    The function is:
-        x.relu() + alpha * (cutoff - x).relu()
-    E.g. consider cutoff = -1, alpha = 0.01.  This will tend to prevent die-off
-    of neurons.
-    """
-    @staticmethod
-    def forward(ctx, x: Tensor, cutoff: float, alpha: float) -> Tensor:
-        mask1 = (x >= 0)  # >=, so there is deriv if x == 0.
-        p = cutoff - x
-        mask2 = (p >= 0)
-        ctx.save_for_backward(mask1, mask2)
-        ctx.alpha = alpha
-        return x.relu() + alpha * p.relu()
-
-    @staticmethod
-    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None]:
-        mask1, mask2 = ctx.saved_tensors
-        return mask1 * ans_grad - (ctx.alpha * mask2) * ans_grad, None, None
-
-
-class PeLU(torch.nn.Module):
-    def __init__(self, cutoff: float = -1.0, alpha: float = 0.01) -> None:
-        super(PeLU, self).__init__()
-        self.cutoff = cutoff
-        self.alpha = alpha
-
-    def forward(self, x: Tensor) -> Tensor:
-        return PeLUFunction.apply(x, self.cutoff, self.alpha)
-
-
-class ExpScale(torch.nn.Module):
-    def __init__(self, *shape, speed: float = 1.0, initial_scale: float = 1.0):
-        super(ExpScale, self).__init__()
-        scale = torch.tensor(initial_scale)
-        scale = scale.log() / speed
-        self.scale = nn.Parameter(scale.detach())
-        self.speed = speed
-
-    def forward(self, x: Tensor) -> Tensor:
-        return x * (self.scale * self.speed).exp()
-
-
-def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor:
-    # double-swish, implemented/approximated as offset-swish
-    x = (x * torch.sigmoid(x - 1.0))
-    x = x * (scale * speed).exp()
-    return x
-
-
-class SwishExpScaleFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor:
-        ctx.save_for_backward(x.detach(), scale.detach())
-        ctx.speed = speed
-        return _exp_scale_swish(x, scale, speed)
-
-    @staticmethod
-    def backward(ctx, y_grad: Tensor) -> Tensor:
-        x, scale = ctx.saved_tensors
-        x.requires_grad = True
-        scale.requires_grad = True
-        with torch.enable_grad():
-            y = _exp_scale_swish(x, scale, ctx.speed)
-            y.backward(gradient=y_grad)
-            return x.grad, scale.grad, None
-
-
-class SwishExpScale(torch.nn.Module):
-    # combines ExpScale and a Swish (actually the ExpScale is after the Swish).
-    # caution: need to specify name for speed, e.g. SwishExpScale(50, speed=4.0)
-    def __init__(self, *shape, speed: float = 1.0):
-        super(SwishExpScale, self).__init__()
-
-        initial_log_scale = torch.zeros(()).detach()
-        self.scale = nn.Parameter(initial_log_scale)
-        self.speed = speed
-
-    def forward(self, x: Tensor) -> Tensor:
-        return SwishExpScaleFunction.apply(x, self.scale, self.speed)
-        # x = (x * torch.sigmoid(x))
-        # x = (x * torch.sigmoid(x))
-        # x = x * (self.scale * self.speed).exp()
-        # return x
-
-
-def _exp_scale_relu(x: Tensor, scale: Tensor, speed: float) -> Tensor:
-    return (x * (scale * speed).exp()).relu()
-
-
-class ExpScaleReluFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor:
-        ctx.save_for_backward(x.detach(), scale.detach())
-        ctx.speed = speed
-        return _exp_scale_relu(x, scale, speed)
-
-    @staticmethod
-    def backward(ctx, y_grad: Tensor) -> Tensor:
-        x, scale = ctx.saved_tensors
-        x.requires_grad = True
-        scale.requires_grad = True
-        with torch.enable_grad():
-            y = _exp_scale_relu(x, scale, ctx.speed)
-            y.backward(gradient=y_grad)
-            return x.grad, scale.grad, None
-
-
-class ExpScaleRelu(torch.nn.Module):
-    # combines ExpScale and ReLU.
-    # caution: need to specify name for speed, e.g. ExpScaleRelu(50, speed=4.0)
-    def __init__(self, *shape, speed: float = 1.0):
-        super(ExpScaleRelu, self).__init__()
-        self.scale = nn.Parameter(torch.zeros(*shape))
-        self.speed = speed
-
-    def forward(self, x: Tensor) -> Tensor:
-        return ExpScaleReluFunction.apply(x, self.scale, self.speed)
-        # return (x * torch.sigmoid(x)) * (self.scale * self.speed).exp()
-        # return x * (self.scale * self.speed).exp()
-
-
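Note: the removed PeLUFunction hand-writes its backward pass from boolean masks saved in forward. Below is a minimal sketch (illustrative only, not part of this commit) of how that backward can be checked against plain autograd on the closed-form expression from its docstring; it assumes the removed PeLUFunction definition has been copied back into scope.

import torch

# Assumption: PeLUFunction from the removed code above is available in scope.
cutoff, alpha = -1.0, 0.01

x1 = torch.randn(50, 60, dtype=torch.float64, requires_grad=True)
x2 = x1.detach().clone().requires_grad_(True)

# Custom autograd.Function vs. the closed-form expression from its docstring.
y1 = PeLUFunction.apply(x1, cutoff, alpha)
y2 = x2.relu() + alpha * (cutoff - x2).relu()
assert torch.allclose(y1, y2)

# The gradients should agree as well; the kinks at x == 0 and x == cutoff
# have measure zero, so random inputs will not land exactly on them.
y1.sum().backward()
y2.sum().backward()
assert torch.allclose(x1.grad, x2.grad)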
@@ -639,40 +515,6 @@ class DoubleSwish(torch.nn.Module):
         """
         return DoubleSwishFunction.apply(x)


-def _test_exp_scale_swish():
-    x1 = torch.randn(50, 60).detach()
-    x2 = x1.detach()
-
-    m1 = SwishExpScale(50, 1, speed=4.0)
-    m2 = torch.nn.Sequential(DoubleSwish(), ExpScale(50, 1, speed=4.0))
-    x1.requires_grad = True
-    x2.requires_grad = True
-
-    y1 = m1(x1)
-    y2 = m2(x2)
-    assert torch.allclose(y1, y2, atol=1e-05)
-    y1.sum().backward()
-    y2.sum().backward()
-    assert torch.allclose(x1.grad, x2.grad, atol=1e-05)
-
-
-def _test_exp_scale_relu():
-    x1 = torch.randn(50, 60).detach()
-    x2 = x1.detach()
-
-    m1 = ExpScaleRelu(50, 1, speed=4.0)
-    m2 = torch.nn.Sequential(nn.ReLU(), ExpScale(50, 1, speed=4.0))
-    x1.requires_grad = True
-    x2.requires_grad = True
-
-    y1 = m1(x1)
-    y2 = m2(x2)
-    assert torch.allclose(y1, y2)
-    y1.sum().backward()
-    y2.sum().backward()
-    assert torch.allclose(x1.grad, x2.grad)
-
-
 def _test_deriv_balancer_sign():
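The tests removed above checked SwishExpScale and ExpScaleRelu against reference compositions of DoubleSwish/nn.ReLU followed by ExpScale, in both the forward values and the gradients. An alternative check, sketched below under the assumption that SwishExpScaleFunction from the first hunk is still importable (illustrative only, not part of this commit), is torch.autograd.gradcheck, which compares the hand-written backward against finite differences.

import torch
from torch.autograd import gradcheck

# Assumption: SwishExpScaleFunction from the removed code above is in scope.
# The underlying expression x * sigmoid(x - 1) * exp(scale * speed) is smooth,
# so gradcheck with float64 inputs and default tolerances is reliable here.
x = torch.randn(8, 12, dtype=torch.float64, requires_grad=True)
scale = torch.zeros((), dtype=torch.float64, requires_grad=True)
speed = 4.0

assert gradcheck(lambda a, s: SwishExpScaleFunction.apply(a, s, speed), (x, scale))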
@@ -737,6 +579,4 @@ def _test_basic_norm():
 if __name__ == '__main__':
     _test_deriv_balancer_sign()
     _test_deriv_balancer_magnitude()
-    _test_exp_scale_swish()
-    _test_exp_scale_relu()
     _test_basic_norm()
@@ -19,7 +19,7 @@ import copy
 import math
 import warnings
 from typing import Optional, Tuple, Sequence
-from subsampling import PeLU, ExpScale, DoubleSwish, SwishExpScale, ExpScaleRelu, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d
+from subsampling import DoubleSwish, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d

 import torch
 from torch import Tensor, nn