Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)
Remove ExpScale in feedforward layers.
parent 97c0bb82d3
commit f351777e9c
@@ -565,8 +565,13 @@ class DerivBalancer(torch.nn.Module):
 
 class DoubleSwish(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
-        """Return Swich activation function."""
-        return x * torch.sigmoid(x - 1.0)
+        """Return double-swish activation function which is an approximation to Swish(Swish(x)),
+        that we approximate closely with x * sigmoid(x-1), expressed for more memory-efficient
+        backprop as (x-1) * torch.sigmoid(x - 1) + torch.sigmoid(x - 1)
+        """
+        x1 = x - 1.0
+        s = torch.sigmoid(x1)
+        return (x1 * s) + s  # (x-1) * s + s == x * s
 
 
 def _test_exp_scale_swish():
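As a quick aside (not part of the commit), here is a minimal standalone check of the identity the new docstring relies on, x * sigmoid(x - 1) == (x - 1) * sigmoid(x - 1) + sigmoid(x - 1), and that gradients flow through the rearranged form; the atol mirrors the tolerance added to the tests below:

import torch

x = torch.randn(100, requires_grad=True)

# Direct form: x * sigmoid(x - 1)
y_direct = x * torch.sigmoid(x - 1.0)

# Rearranged, more memory-efficient form used in the commit
x1 = x - 1.0
s = torch.sigmoid(x1)
y_rearranged = (x1 * s) + s

# Equal up to floating-point rounding, hence the explicit tolerance
assert torch.allclose(y_direct, y_rearranged, atol=1e-05)
y_rearranged.sum().backward()  # backprop through the rearranged form works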
@@ -581,10 +586,10 @@ def _test_exp_scale_swish():
 
     y1 = m1(x1)
     y2 = m2(x2)
-    assert torch.allclose(y1, y2)
+    assert torch.allclose(y1, y2, atol=1e-05)
     y1.sum().backward()
     y2.sum().backward()
-    assert torch.allclose(x1.grad, x2.grad)
+    assert torch.allclose(x1.grad, x2.grad, atol=1e-05)
 
 def _test_exp_scale_relu():
 
@@ -19,7 +19,7 @@ import copy
 import math
 import warnings
 from typing import Optional, Tuple, Sequence
-from subsampling import PeLU, ExpScale, SwishExpScale, ExpScaleRelu, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d
+from subsampling import PeLU, ExpScale, DoubleSwish, SwishExpScale, ExpScaleRelu, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d
 
 import torch
 from torch import Tensor, nn
@@ -159,7 +159,7 @@ class ConformerEncoderLayer(nn.Module):
         self.feed_forward = nn.Sequential(
             ScaledLinear(d_model, dim_feedforward),
             DerivBalancer(channel_dim=-1),
-            SwishExpScale(dim_feedforward, speed=20.0),
+            DoubleSwish(),
             nn.Dropout(dropout),
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
         )
@@ -167,7 +167,7 @@ class ConformerEncoderLayer(nn.Module):
         self.feed_forward_macaron = nn.Sequential(
             ScaledLinear(d_model, dim_feedforward),
             DerivBalancer(channel_dim=-1),
-            SwishExpScale(dim_feedforward, speed=20.0),
+            DoubleSwish(),
             nn.Dropout(dropout),
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
         )
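To illustrate the change in isolation, here is a rough sketch of the feedforward stack after this commit; plain nn.Linear stands in for the repo's ScaledLinear and the DerivBalancer is left out, so this is an approximation, not the actual module:

import torch
from torch import Tensor, nn

class DoubleSwish(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        # x * sigmoid(x - 1), computed as (x - 1) * s + s
        x1 = x - 1.0
        s = torch.sigmoid(x1)
        return (x1 * s) + s

d_model, dim_feedforward, dropout = 512, 2048, 0.1
feed_forward = nn.Sequential(
    nn.Linear(d_model, dim_feedforward),  # ScaledLinear in the actual code
    DoubleSwish(),                        # replaces SwishExpScale(dim_feedforward, speed=20.0)
    nn.Dropout(dropout),
    nn.Linear(dim_feedforward, d_model),  # ScaledLinear(..., initial_scale=0.25) in the actual code
)

y = feed_forward(torch.randn(10, d_model))  # (10, 512) in, (10, 512) out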
@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95",
+        default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved