Introduce in_scale=0.5 for SwishExpScale

This commit is contained in:
Daniel Povey 2022-03-11 19:05:55 +08:00
parent a0d5e2932c
commit 98156711ef
3 changed files with 15 additions and 10 deletions

View File

@ -221,18 +221,21 @@ class ExpScale(torch.nn.Module):
def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor: def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float, in_scale: float) -> Tensor:
# double-swish, implemented/approximated as offset-swish # double-swish, implemented/approximated as offset-swish
if in_scale != 1.0:
x = x * in_scale
x = (x * torch.sigmoid(x - 1.0)) x = (x * torch.sigmoid(x - 1.0))
x = x * (scale * speed).exp() x = x * (scale * speed).exp()
return x return x
class SwishExpScaleFunction(torch.autograd.Function): class SwishExpScaleFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor: def forward(ctx, x: Tensor, scale: Tensor, speed: float, in_scale: float) -> Tensor:
ctx.save_for_backward(x.detach(), scale.detach()) ctx.save_for_backward(x.detach(), scale.detach())
ctx.speed = speed ctx.speed = speed
return _exp_scale_swish(x, scale, speed) ctx.in_scale = in_scale
return _exp_scale_swish(x, scale, speed, in_scale)
@staticmethod @staticmethod
def backward(ctx, y_grad: Tensor) -> Tensor: def backward(ctx, y_grad: Tensor) -> Tensor:
@ -240,21 +243,23 @@ class SwishExpScaleFunction(torch.autograd.Function):
x.requires_grad = True x.requires_grad = True
scale.requires_grad = True scale.requires_grad = True
with torch.enable_grad(): with torch.enable_grad():
y = _exp_scale_swish(x, scale, ctx.speed) y = _exp_scale_swish(x, scale, ctx.speed, ctx.in_scale)
y.backward(gradient=y_grad) y.backward(gradient=y_grad)
return x.grad, scale.grad, None return x.grad, scale.grad, None, None
class SwishExpScale(torch.nn.Module): class SwishExpScale(torch.nn.Module):
# combines ExpScale and a Swish (actually the ExpScale is after the Swish). # combines ExpScale and a Swish (actually the ExpScale is after the Swish).
# caution: need to specify name for speed, e.g. SwishExpScale(50, speed=4.0) # caution: need to specify name for speed, e.g. SwishExpScale(50, speed=4.0)
def __init__(self, *shape, speed: float = 1.0): #
def __init__(self, *shape, speed: float = 1.0, in_scale: float = 1.0):
super(SwishExpScale, self).__init__() super(SwishExpScale, self).__init__()
self.in_scale = in_scale
self.scale = nn.Parameter(torch.zeros(*shape)) self.scale = nn.Parameter(torch.zeros(*shape))
self.speed = speed self.speed = speed
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
return SwishExpScaleFunction.apply(x, self.scale, self.speed) return SwishExpScaleFunction.apply(x, self.scale, self.speed, self.in_scale)
# x = (x * torch.sigmoid(x)) # x = (x * torch.sigmoid(x))
# x = (x * torch.sigmoid(x)) # x = (x * torch.sigmoid(x))
# x = x * (self.scale * self.speed).exp() # x = x * (self.scale * self.speed).exp()

View File

@ -160,7 +160,7 @@ class ConformerEncoderLayer(nn.Module):
nn.Linear(d_model, dim_feedforward), nn.Linear(d_model, dim_feedforward),
DerivBalancer(channel_dim=-1, threshold=0.05, DerivBalancer(channel_dim=-1, threshold=0.05,
max_factor=0.01), max_factor=0.01),
SwishExpScale(dim_feedforward, speed=20.0), SwishExpScale(dim_feedforward, speed=20.0, in_scale=0.5),
nn.Dropout(dropout), nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model), nn.Linear(dim_feedforward, d_model),
) )
@ -169,7 +169,7 @@ class ConformerEncoderLayer(nn.Module):
nn.Linear(d_model, dim_feedforward), nn.Linear(d_model, dim_feedforward),
DerivBalancer(channel_dim=-1, threshold=0.05, DerivBalancer(channel_dim=-1, threshold=0.05,
max_factor=0.01), max_factor=0.01),
SwishExpScale(dim_feedforward, speed=20.0), SwishExpScale(dim_feedforward, speed=20.0, in_scale=0.5),
nn.Dropout(dropout), nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model), nn.Linear(dim_feedforward, d_model),
) )

View File

@ -110,7 +110,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="transducer_stateless/randcombine1_expscale3_rework", default="transducer_stateless/randcombine1_expscale3_rework_0.5",
help="""The experiment dir. help="""The experiment dir.
It specifies the directory where all training related It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved files, e.g., checkpoints, log, etc, are saved