Add PeLU to this good-performing setup.

Daniel Povey 2022-03-02 16:33:27 +08:00
parent c1063def95
commit 9d1b4ae046
3 changed files with 42 additions and 15 deletions

File 1 of 3

@@ -45,13 +45,14 @@ class Conv2dSubsampling(nn.Module):
             nn.Conv2d(
                 in_channels=1, out_channels=odim, kernel_size=3, stride=2
             ),
-            nn.ReLU(),
+            PeLU(cutoff=-1.0),
             nn.Conv2d(
                 in_channels=odim, out_channels=odim, kernel_size=3, stride=2
             ),
-            nn.ReLU(),
+            PeLU(cutoff=-5.0),
         )
         self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
+        self.out_norm = nn.LayerNorm(odim, elementwise_affine=False)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.
@@ -70,6 +71,7 @@ class Conv2dSubsampling(nn.Module):
         b, c, t, f = x.size()
         x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
         # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
+        x = self.out_norm(x)
         return x
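The new out_norm uses elementwise_affine=False, so it has no learnable weight or bias; it only standardizes each output frame over the feature dimension. A small sketch (not part of the commit; the odim value is illustrative):

```python
import torch
import torch.nn as nn

odim = 512  # illustrative value
norm = nn.LayerNorm(odim, elementwise_affine=False)  # no learnable scale/bias

x = torch.randn(8, 100, odim) * 3.0 + 1.0  # (N, T, odim), shifted and scaled
y = norm(x)

print(y.mean(dim=-1).abs().max())  # ~0: each frame is centered
print(y.std(dim=-1).mean())        # ~1: and scaled to unit variance
print(list(norm.parameters()))     # []: nothing to train
```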
@@ -159,3 +161,35 @@ class VggSubsampling(nn.Module):
         b, c, t, f = x.size()
         x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
         return x
+
+
+class PeLUFunction(torch.autograd.Function):
+    """
+    Computes the PeLU function (PeLUFunction.apply(x, cutoff, alpha)).
+    The function is:
+        x.relu() + alpha * (cutoff - x).relu()
+    E.g. consider cutoff = -1, alpha = 0.01.  This will tend to prevent
+    die-off of neurons.
+    """
+
+    @staticmethod
+    def forward(ctx, x: Tensor, cutoff: float, alpha: float) -> Tensor:
+        mask1 = x >= 0  # >=, so there is a deriv if x == 0.
+        p = cutoff - x
+        mask2 = p >= 0
+        ctx.save_for_backward(mask1, mask2)
+        ctx.alpha = alpha
+        return x.relu() + alpha * p.relu()
+
+    @staticmethod
+    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None]:
+        mask1, mask2 = ctx.saved_tensors
+        # d/dx [x.relu()] = mask1; d/dx [alpha * (cutoff - x).relu()] = -alpha * mask2
+        return mask1 * ans_grad - (ctx.alpha * mask2) * ans_grad, None, None
+
+
+class PeLU(torch.nn.Module):
+    def __init__(self, cutoff: float = -1.0, alpha: float = 0.01) -> None:
+        super(PeLU, self).__init__()
+        self.cutoff = cutoff
+        self.alpha = alpha
+
+    def forward(self, x: Tensor) -> Tensor:
+        return PeLUFunction.apply(x, self.cutoff, self.alpha)
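A quick sanity check (not part of the commit; it assumes PeLUFunction, Tensor, and Tuple are in scope, e.g. when appended to this file) that the hand-written backward agrees with autograd on the closed form:

```python
import torch

cutoff, alpha = -1.0, 0.01
x1 = torch.randn(1000, dtype=torch.double, requires_grad=True)
x2 = x1.detach().clone().requires_grad_(True)

# Custom Function with the hand-written backward.
y1 = PeLUFunction.apply(x1, cutoff, alpha)
y1.sum().backward()

# Same closed form, differentiated by autograd.
y2 = x2.relu() + alpha * (cutoff - x2).relu()
y2.sum().backward()

assert torch.equal(y1, y2)
# Gradients agree everywhere except exactly at the kinks (x == 0, x == cutoff),
# which random double-precision inputs will not hit.
assert torch.allclose(x1.grad, x2.grad)
```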

File 2 of 3

@@ -19,6 +19,7 @@ import copy
 import math
 import warnings
 from typing import Optional, Tuple, Sequence
+from subsampling import PeLU
 
 import torch
 from torch import Tensor, nn
@@ -84,12 +85,7 @@ class Conformer(Transformer):
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers,
                                         aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period)))
         self.normalize_before = normalize_before
-        if self.normalize_before:
-            self.after_norm = nn.LayerNorm(d_model)  # TODO: remove.
-        else:
-            # Note: TorchScript detects that self.after_norm could be used inside forward()
-            # and throws an error without this change.
-            self.after_norm = identity
 
     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor
@@ -118,9 +114,6 @@ class Conformer(Transformer):
         x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, N, C)
 
-        if self.normalize_before:
-            x = self.after_norm(x)
-
         logits = self.encoder_output_layer(x)
         logits = logits.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
@@ -163,14 +156,14 @@ class ConformerEncoderLayer(nn.Module):
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
-            Swish(),
+            PeLU(),
             nn.Dropout(dropout),
             nn.Linear(dim_feedforward, d_model),
         )
 
         self.feed_forward_macaron = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
-            Swish(),
+            PeLU(),
             nn.Dropout(dropout),
             nn.Linear(dim_feedforward, d_model),
         )
@@ -889,7 +882,7 @@ class ConvolutionModule(nn.Module):
             padding=0,
             bias=bias,
         )
-        self.activation = Swish()
+        self.activation = PeLU()
 
     def forward(self, x: Tensor) -> Tensor:
         """Compute convolution module.

File 3 of 3

@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/specaugmod_baseline_randcombine1",
+        default="transducer_stateless/specaugmod_baseline_randcombine1_pelu",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved