pelu_base->expscale, add 2xExpScale in subsampling, and in feedforward units.

This commit is contained in:
Daniel Povey 2022-03-03 23:52:03 +08:00
parent 3fb559d2f0
commit 5c177fc52b
3 changed files with 15 additions and 2 deletions

View File

@ -48,10 +48,12 @@ class Conv2dSubsampling(nn.Module):
in_channels=1, out_channels=odim, kernel_size=3, stride=2 in_channels=1, out_channels=odim, kernel_size=3, stride=2
), ),
nn.ReLU(), nn.ReLU(),
ExpScale(odim, 1, 1, speed=2.0),
nn.Conv2d( nn.Conv2d(
in_channels=odim, out_channels=odim, kernel_size=3, stride=2 in_channels=odim, out_channels=odim, kernel_size=3, stride=2
), ),
nn.ReLU(), nn.ReLU(),
ExpScale(odim, 1, 1, speed=2.0),
) )
self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
self.out_norm = nn.LayerNorm(odim, elementwise_affine=False) self.out_norm = nn.LayerNorm(odim, elementwise_affine=False)
@ -195,3 +197,12 @@ class PeLU(torch.nn.Module):
self.alpha = alpha self.alpha = alpha
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
return PeLUFunction.apply(x, self.cutoff, self.alpha) return PeLUFunction.apply(x, self.cutoff, self.alpha)
class ExpScale(torch.nn.Module):
    """Multiply the input by a learnable, strictly positive scale.

    The learnable parameter ``scale`` is stored in the log domain and the
    multiplier actually applied is ``exp(scale * speed)``.  Because ``scale``
    is initialised to zeros, the multiplier starts at exactly 1 and the
    module is an identity mapping right after construction.

    Args:
        *shape: Dimensions of the ``scale`` parameter, forwarded to
            ``torch.zeros``.  The resulting tensor must be broadcastable
            against the input given to ``forward``.
        speed: Constant factor applied to ``scale`` before exponentiation.
            A larger value makes the multiplier respond more strongly to a
            given change in the parameter.
    """

    def __init__(self, *shape, speed: float = 1.0):
        super().__init__()
        # Log-domain scale; zeros => exp(0) == 1, i.e. no scaling at init.
        self.scale = nn.Parameter(torch.zeros(*shape))
        # Plain Python float, not a parameter or buffer: it is not trained
        # and is not stored in the module's state_dict.
        self.speed = speed

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x * exp(scale * speed)`` using normal broadcasting."""
        multiplier = (self.scale * self.speed).exp()
        return x * multiplier

View File

@ -19,7 +19,7 @@ import copy
import math import math
import warnings import warnings
from typing import Optional, Tuple, Sequence from typing import Optional, Tuple, Sequence
from subsampling import PeLU from subsampling import PeLU, ExpScale
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
@ -157,6 +157,7 @@ class ConformerEncoderLayer(nn.Module):
self.feed_forward = nn.Sequential( self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward), nn.Linear(d_model, dim_feedforward),
Swish(), Swish(),
ExpScale(dim_feedforward, speed=2.0),
nn.Dropout(dropout), nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model), nn.Linear(dim_feedforward, d_model),
) )
@ -164,6 +165,7 @@ class ConformerEncoderLayer(nn.Module):
self.feed_forward_macaron = nn.Sequential( self.feed_forward_macaron = nn.Sequential(
nn.Linear(d_model, dim_feedforward), nn.Linear(d_model, dim_feedforward),
Swish(), Swish(),
ExpScale(dim_feedforward, speed=2.0),
nn.Dropout(dropout), nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model), nn.Linear(dim_feedforward, d_model),
) )

View File

@ -110,7 +110,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="transducer_stateless/specaugmod_baseline_randcombine1_pelu_base", default="transducer_stateless/specaugmod_baseline_randcombine1_expscale",
help="""The experiment dir. help="""The experiment dir.
It specifies the directory where all training related It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved files, e.g., checkpoints, log, etc, are saved