Mirror of https://github.com/k2-fsa/icefall.git

Commit 9d1b4ae046 (parent c1063def95)

    Add PeLU to this good-performing setup.
@@ -45,13 +45,14 @@ class Conv2dSubsampling(nn.Module):
             nn.Conv2d(
                 in_channels=1, out_channels=odim, kernel_size=3, stride=2
             ),
-            nn.ReLU(),
+            PeLU(cutoff=-1.0),
             nn.Conv2d(
                 in_channels=odim, out_channels=odim, kernel_size=3, stride=2
             ),
-            nn.ReLU(),
+            PeLU(cutoff=-5.0),
         )
         self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
+        self.out_norm = nn.LayerNorm(odim, elementwise_affine=False)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.
@@ -70,6 +71,7 @@ class Conv2dSubsampling(nn.Module):
         b, c, t, f = x.size()
         x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
         # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
+        x = self.out_norm(x)
         return x

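For context, a minimal sketch (not part of the commit) of what the new out_norm adds: nn.LayerNorm(odim, elementwise_affine=False) normalizes each output frame over the feature dimension without introducing learnable scale or bias parameters.

    import torch
    import torch.nn as nn

    odim = 4
    out_norm = nn.LayerNorm(odim, elementwise_affine=False)  # no learnable affine

    x = torch.randn(2, 10, odim)                 # (N, T', odim), as after self.out
    y = out_norm(x)
    print(y.mean(dim=-1).abs().max())            # ~0: each frame is zero-mean
    print(y.std(dim=-1, unbiased=False).mean())  # ~1: and (biased) unit std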
@@ -159,3 +161,35 @@ class VggSubsampling(nn.Module):
         b, c, t, f = x.size()
         x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
         return x
+
+
+class PeLUFunction(torch.autograd.Function):
+    """
+    Computes PeLU function (PeLUFunction.apply(x, cutoff, alpha)).
+    The function is:
+        x.relu() + alpha * (cutoff - x).relu()
+    E.g. consider cutoff = -1, alpha = 0.01.  This will tend to prevent die-off
+    of neurons.
+    """
+    @staticmethod
+    def forward(ctx, x: Tensor, cutoff: float, alpha: float) -> Tensor:
+        mask1 = (x >= 0)  # >=, so there is deriv if x == 0.
+        p = cutoff - x
+        mask2 = (p >= 0)
+        ctx.save_for_backward(mask1, mask2)
+        ctx.alpha = alpha
+        return x.relu() + alpha * p.relu()
+    @staticmethod
+    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None]:
+        mask1, mask2 = ctx.saved_tensors
+        return mask1 * ans_grad - (ctx.alpha * mask2) * ans_grad, None, None
+
+
+
+class PeLU(torch.nn.Module):
+    def __init__(self, cutoff: float = -1.0, alpha: float = 0.01) -> None:
+        super(PeLU, self).__init__()
+        self.cutoff = cutoff
+        self.alpha = alpha
+    def forward(self, x: Tensor) -> Tensor:
+        return PeLUFunction.apply(x, self.cutoff, self.alpha)
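The hunks above complete the changes to the subsampling module; the hunks below apply the new activation inside the Conformer model and update the training script's default experiment directory. As a quick illustration (not part of the commit), the hand-written backward of PeLUFunction can be sanity-checked against plain autograd applied to the same expression; this sketch assumes the subsampling.py shown above is importable from the current directory.

    # Sketch only: compare PeLUFunction's custom gradient with autograd's.
    import torch
    from subsampling import PeLUFunction

    cutoff, alpha = -1.0, 0.01
    x = torch.randn(1000, requires_grad=True, dtype=torch.double)

    # Custom autograd.Function path.
    PeLUFunction.apply(x, cutoff, alpha).sum().backward()
    grad_custom = x.grad.clone()

    # Reference: the same formula, differentiated by plain autograd.
    x.grad = None
    (x.relu() + alpha * (cutoff - x).relu()).sum().backward()

    # Expected to match (up to ties exactly at x == 0 or x == cutoff).
    print(torch.allclose(grad_custom, x.grad))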
@@ -19,6 +19,7 @@ import copy
 import math
 import warnings
 from typing import Optional, Tuple, Sequence
+from subsampling import PeLU

 import torch
 from torch import Tensor, nn
@@ -84,12 +85,7 @@ class Conformer(Transformer):
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers,
                                         aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period)))
         self.normalize_before = normalize_before
-        if self.normalize_before:
-            self.after_norm = nn.LayerNorm(d_model)  # TODO: remove.
-        else:
-            # Note: TorchScript detects that self.after_norm could be used inside forward()
-            # and throws an error without this change.
-            self.after_norm = identity

     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor
@@ -118,9 +114,6 @@ class Conformer(Transformer):

         x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, N, C)

-        if self.normalize_before:
-            x = self.after_norm(x)
-
         logits = self.encoder_output_layer(x)
         logits = logits.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)

@@ -163,14 +156,14 @@ class ConformerEncoderLayer(nn.Module):

         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
-            Swish(),
+            PeLU(),
             nn.Dropout(dropout),
             nn.Linear(dim_feedforward, d_model),
         )

         self.feed_forward_macaron = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
-            Swish(),
+            PeLU(),
             nn.Dropout(dropout),
             nn.Linear(dim_feedforward, d_model),
         )
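Note that the feed-forward modules call PeLU() with no arguments, so they pick up the defaults from the constructor added above (cutoff=-1.0, alpha=0.01), i.e. the activation computed is relu(x) + 0.01 * relu(-1.0 - x). A tiny sketch (not part of the commit) of those defaults, again assuming subsampling.py is importable:

    import torch
    from subsampling import PeLU

    act = PeLU()  # defaults: cutoff=-1.0, alpha=0.01
    x = torch.tensor([-3.0, -1.0, 0.0, 2.0])
    print(act(x))  # tensor([0.0200, 0.0000, 0.0000, 2.0000])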
@@ -889,7 +882,7 @@ class ConvolutionModule(nn.Module):
             padding=0,
             bias=bias,
         )
-        self.activation = Swish()
+        self.activation = PeLU()

     def forward(self, x: Tensor) -> Tensor:
         """Compute convolution module.
@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/specaugmod_baseline_randcombine1",
+        default="transducer_stateless/specaugmod_baseline_randcombine1_pelu",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
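With this change, checkpoints and logs for runs using the PeLU activation default to transducer_stateless/specaugmod_baseline_randcombine1_pelu rather than the previous baseline directory, so the two setups do not overwrite each other; the location can still be overridden on the command line via --exp-dir.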