From 9d1b4ae04682d12aef2fecb902f318fcf9cab716 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 2 Mar 2022 16:33:27 +0800
Subject: [PATCH] Add pelu to this good-performing setup..

---
 .../ASR/conformer_ctc/subsampling.py       | 38 ++++++++++++++++++-
 .../ASR/transducer_stateless/conformer.py  | 17 +++------
 .../ASR/transducer_stateless/train.py      |  2 +-
 3 files changed, 42 insertions(+), 15 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py
index 542fb0364..b23071926 100644
--- a/egs/librispeech/ASR/conformer_ctc/subsampling.py
+++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py
@@ -45,13 +45,14 @@ class Conv2dSubsampling(nn.Module):
             nn.Conv2d(
                 in_channels=1, out_channels=odim, kernel_size=3, stride=2
             ),
-            nn.ReLU(),
+            PeLU(cutoff=-1.0),
             nn.Conv2d(
                 in_channels=odim, out_channels=odim, kernel_size=3, stride=2
             ),
-            nn.ReLU(),
+            PeLU(cutoff=-5.0),
         )
         self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
+        self.out_norm = nn.LayerNorm(odim, elementwise_affine=False)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.
@@ -70,6 +71,7 @@ class Conv2dSubsampling(nn.Module):
         b, c, t, f = x.size()
         x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
         # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
+        x = self.out_norm(x)
         return x
 
 
@@ -159,3 +161,35 @@ class VggSubsampling(nn.Module):
         b, c, t, f = x.size()
         x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
         return x
+
+
+class PeLUFunction(torch.autograd.Function):
+    """
+    Computes PeLU function (PeLUFunction.apply(x, cutoff, alpha)).
+    The function is:
+        x.relu() + alpha * (cutoff - x).relu()
+    E.g. consider cutoff = -1, alpha = 0.01.  This will tend to prevent die-off
+    of neurons.
+    """
+    @staticmethod
+    def forward(ctx, x: Tensor, cutoff: float, alpha: float) -> Tensor:
+        mask1 = (x >= 0)  # >=, so there is deriv if x == 0.
+        p = cutoff - x
+        mask2 = (p >= 0)
+        ctx.save_for_backward(mask1, mask2)
+        ctx.alpha = alpha
+        return x.relu() + alpha * p.relu()
+    @staticmethod
+    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None]:
+        mask1, mask2 = ctx.saved_tensors
+        return mask1 * ans_grad - (ctx.alpha * mask2) * ans_grad, None, None
+
+
+
+class PeLU(torch.nn.Module):
+    def __init__(self, cutoff: float = -1.0, alpha: float = 0.01) -> None:
+        super(PeLU, self).__init__()
+        self.cutoff = cutoff
+        self.alpha = alpha
+    def forward(self, x: Tensor) -> Tensor:
+        return PeLUFunction.apply(x, self.cutoff, self.alpha)
diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index 327849485..066232a02 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -19,6 +19,7 @@ import copy
 import math
 import warnings
 from typing import Optional, Tuple, Sequence
+from subsampling import PeLU
 
 import torch
 from torch import Tensor, nn
@@ -84,12 +85,7 @@ class Conformer(Transformer):
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers,
                                         aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period)))
         self.normalize_before = normalize_before
-        if self.normalize_before:
-            self.after_norm = nn.LayerNorm(d_model)  # TODO: remove.
-        else:
-            # Note: TorchScript detects that self.after_norm could be used inside forward()
-            #       and throws an error without this change.
-            self.after_norm = identity
+
 
     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor
@@ -118,9 +114,6 @@
 
         x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, N, C)
 
-        if self.normalize_before:
-            x = self.after_norm(x)
-
         logits = self.encoder_output_layer(x)
         logits = logits.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
 
@@ -163,14 +156,14 @@ class ConformerEncoderLayer(nn.Module):
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
-            Swish(),
+            PeLU(),
             nn.Dropout(dropout),
             nn.Linear(dim_feedforward, d_model),
         )
 
         self.feed_forward_macaron = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
-            Swish(),
+            PeLU(),
             nn.Dropout(dropout),
             nn.Linear(dim_feedforward, d_model),
         )
@@ -889,7 +882,7 @@ class ConvolutionModule(nn.Module):
             padding=0,
             bias=bias,
         )
-        self.activation = Swish()
+        self.activation = PeLU()
 
     def forward(self, x: Tensor) -> Tensor:
         """Compute convolution module.
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index 8877d4e75..88b366245 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/specaugmod_baseline_randcombine1",
+        default="transducer_stateless/specaugmod_baseline_randcombine1_pelu",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved