Add pelu to this good-performing setup..

This commit is contained in:
Daniel Povey 2022-03-02 16:33:27 +08:00
parent c1063def95
commit 9d1b4ae046
3 changed files with 42 additions and 15 deletions

View File

@ -45,13 +45,14 @@ class Conv2dSubsampling(nn.Module):
nn.Conv2d(
in_channels=1, out_channels=odim, kernel_size=3, stride=2
),
nn.ReLU(),
PeLU(cutoff=-1.0),
nn.Conv2d(
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
),
nn.ReLU(),
PeLU(cutoff=-5.0),
)
self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
self.out_norm = nn.LayerNorm(odim, elementwise_affine=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
@ -70,6 +71,7 @@ class Conv2dSubsampling(nn.Module):
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
x = self.out_norm(x)
return x
@ -159,3 +161,35 @@ class VggSubsampling(nn.Module):
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
return x
class PeLUFunction(torch.autograd.Function):
"""
Computes PeLU function (PeLUFunction.apply(x, cutoff, alpha)).
The function is:
x.relu() + alpha * (cutoff - x).relu()
E.g. consider cutoff = -1, alpha = 0.01. This will tend to prevent die-off
of neurons.
"""
@staticmethod
def forward(ctx, x: Tensor, cutoff: float, alpha: float) -> Tensor:
mask1 = (x >= 0) # >=, so there is deriv if x == 0.
p = cutoff - x
mask2 = (p >= 0)
ctx.save_for_backward(mask1, mask2)
ctx.alpha = alpha
return x.relu() + alpha * p.relu()
@staticmethod
def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None]:
mask1, mask2 = ctx.saved_tensors
return mask1 * ans_grad - (ctx.alpha * mask2) * ans_grad, None, None
class PeLU(torch.nn.Module):
def __init__(self, cutoff: float = -1.0, alpha: float = 0.01) -> None:
super(PeLU, self).__init__()
self.cutoff = cutoff
self.alpha = alpha
def forward(self, x: Tensor) -> Tensor:
return PeLUFunction.apply(x, self.cutoff, self.alpha)

View File

@ -19,6 +19,7 @@ import copy
import math
import warnings
from typing import Optional, Tuple, Sequence
from subsampling import PeLU
import torch
from torch import Tensor, nn
@ -84,12 +85,7 @@ class Conformer(Transformer):
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers,
aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period)))
self.normalize_before = normalize_before
if self.normalize_before:
self.after_norm = nn.LayerNorm(d_model) # TODO: remove.
else:
# Note: TorchScript detects that self.after_norm could be used inside forward()
# and throws an error without this change.
self.after_norm = identity
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
@ -118,9 +114,6 @@ class Conformer(Transformer):
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, N, C)
if self.normalize_before:
x = self.after_norm(x)
logits = self.encoder_output_layer(x)
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
@ -163,14 +156,14 @@ class ConformerEncoderLayer(nn.Module):
self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
Swish(),
PeLU(),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
)
self.feed_forward_macaron = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
Swish(),
PeLU(),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
)
@ -889,7 +882,7 @@ class ConvolutionModule(nn.Module):
padding=0,
bias=bias,
)
self.activation = Swish()
self.activation = PeLU()
def forward(self, x: Tensor) -> Tensor:
"""Compute convolution module.

View File

@ -110,7 +110,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="transducer_stateless/specaugmod_baseline_randcombine1",
default="transducer_stateless/specaugmod_baseline_randcombine1_pelu",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved