Remove ExpScale in feedforward layes.

This commit is contained in:
Daniel Povey 2022-03-13 17:29:39 +08:00
parent 97c0bb82d3
commit f351777e9c
3 changed files with 13 additions and 8 deletions

View File

@ -565,8 +565,13 @@ class DerivBalancer(torch.nn.Module):
class DoubleSwish(torch.nn.Module): class DoubleSwish(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
"""Return Swich activation function.""" """Return double-swish activation function which is an approximation to Swish(Swish(x)),
return x * torch.sigmoid(x - 1.0) that we approximate closely with x * sigmoid(x-1), expressed for more memory-efficient
backprop as (x-1) * torch.sigmoid(x - 1) + torch.sigmoid(x - 1)
"""
x1 = x - 1.0
s = torch.sigmoid(x1)
return (x1 * s) + s # (x-1) * s + s == x * s
def _test_exp_scale_swish(): def _test_exp_scale_swish():
@ -581,10 +586,10 @@ def _test_exp_scale_swish():
y1 = m1(x1) y1 = m1(x1)
y2 = m2(x2) y2 = m2(x2)
assert torch.allclose(y1, y2) assert torch.allclose(y1, y2, atol=1e-05)
y1.sum().backward() y1.sum().backward()
y2.sum().backward() y2.sum().backward()
assert torch.allclose(x1.grad, x2.grad) assert torch.allclose(x1.grad, x2.grad, atol=1e-05)
def _test_exp_scale_relu(): def _test_exp_scale_relu():

View File

@ -19,7 +19,7 @@ import copy
import math import math
import warnings import warnings
from typing import Optional, Tuple, Sequence from typing import Optional, Tuple, Sequence
from subsampling import PeLU, ExpScale, SwishExpScale, ExpScaleRelu, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d from subsampling import PeLU, ExpScale, DoubleSwish, SwishExpScale, ExpScaleRelu, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
@ -159,7 +159,7 @@ class ConformerEncoderLayer(nn.Module):
self.feed_forward = nn.Sequential( self.feed_forward = nn.Sequential(
ScaledLinear(d_model, dim_feedforward), ScaledLinear(d_model, dim_feedforward),
DerivBalancer(channel_dim=-1), DerivBalancer(channel_dim=-1),
SwishExpScale(dim_feedforward, speed=20.0), DoubleSwish(),
nn.Dropout(dropout), nn.Dropout(dropout),
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
) )
@ -167,7 +167,7 @@ class ConformerEncoderLayer(nn.Module):
self.feed_forward_macaron = nn.Sequential( self.feed_forward_macaron = nn.Sequential(
ScaledLinear(d_model, dim_feedforward), ScaledLinear(d_model, dim_feedforward),
DerivBalancer(channel_dim=-1), DerivBalancer(channel_dim=-1),
SwishExpScale(dim_feedforward, speed=20.0), DoubleSwish(),
nn.Dropout(dropout), nn.Dropout(dropout),
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
) )

View File

@ -110,7 +110,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95", default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp",
help="""The experiment dir. help="""The experiment dir.
It specifies the directory where all training related It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved files, e.g., checkpoints, log, etc, are saved