mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 02:22:17 +00:00
pelu_base->expscale, add 2xExpScale in subsampling, and in feedforward units.
This commit is contained in:
parent
3fb559d2f0
commit
5c177fc52b
@ -48,10 +48,12 @@ class Conv2dSubsampling(nn.Module):
|
||||
in_channels=1, out_channels=odim, kernel_size=3, stride=2
|
||||
),
|
||||
nn.ReLU(),
|
||||
ExpScale(odim, 1, 1, speed=2.0),
|
||||
nn.Conv2d(
|
||||
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
|
||||
),
|
||||
nn.ReLU(),
|
||||
ExpScale(odim, 1, 1, speed=2.0),
|
||||
)
|
||||
self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
|
||||
self.out_norm = nn.LayerNorm(odim, elementwise_affine=False)
|
||||
@ -195,3 +197,12 @@ class PeLU(torch.nn.Module):
|
||||
self.alpha = alpha
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return PeLUFunction.apply(x, self.cutoff, self.alpha)
|
||||
|
||||
class ExpScale(torch.nn.Module):
|
||||
def __init__(self, *shape, speed: float = 1.0):
|
||||
super(ExpScale, self).__init__()
|
||||
self.scale = nn.Parameter(torch.zeros(*shape))
|
||||
self.speed = speed
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return x * (self.scale * self.speed).exp()
|
||||
|
@ -19,7 +19,7 @@ import copy
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional, Tuple, Sequence
|
||||
from subsampling import PeLU
|
||||
from subsampling import PeLU, ExpScale
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
@ -157,6 +157,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
self.feed_forward = nn.Sequential(
|
||||
nn.Linear(d_model, dim_feedforward),
|
||||
Swish(),
|
||||
ExpScale(dim_feedforward, speed=2.0),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(dim_feedforward, d_model),
|
||||
)
|
||||
@ -164,6 +165,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
self.feed_forward_macaron = nn.Sequential(
|
||||
nn.Linear(d_model, dim_feedforward),
|
||||
Swish(),
|
||||
ExpScale(dim_feedforward, speed=2.0),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(dim_feedforward, d_model),
|
||||
)
|
||||
|
@ -110,7 +110,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="transducer_stateless/specaugmod_baseline_randcombine1_pelu_base",
|
||||
default="transducer_stateless/specaugmod_baseline_randcombine1_expscale",
|
||||
help="""The experiment dir.
|
||||
It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
|
Loading…
x
Reference in New Issue
Block a user