From 98156711efb11e92d8b50eb426041b62da4a5564 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 11 Mar 2022 19:05:55 +0800
Subject: [PATCH] Introduce in_scale=0.5 for SwishExpScale

---
 .../ASR/conformer_ctc/subsampling.py          | 19 ++++++++++++-------
 .../ASR/transducer_stateless/conformer.py     |  4 ++--
 .../ASR/transducer_stateless/train.py         |  2 +-
 3 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py
index 6b1cb128f..52a58d104 100644
--- a/egs/librispeech/ASR/conformer_ctc/subsampling.py
+++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py
@@ -221,18 +221,21 @@ class ExpScale(torch.nn.Module):
 
 
-def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor:
+def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float, in_scale: float) -> Tensor:
     # double-swish, implemented/approximated as offset-swish
+    if in_scale != 1.0:
+        x = x * in_scale
     x = (x * torch.sigmoid(x - 1.0))
     x = x * (scale * speed).exp()
     return x
 
 
 class SwishExpScaleFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor:
+    def forward(ctx, x: Tensor, scale: Tensor, speed: float, in_scale: float) -> Tensor:
         ctx.save_for_backward(x.detach(), scale.detach())
         ctx.speed = speed
-        return _exp_scale_swish(x, scale, speed)
+        ctx.in_scale = in_scale
+        return _exp_scale_swish(x, scale, speed, in_scale)
 
     @staticmethod
     def backward(ctx, y_grad: Tensor) -> Tensor:
@@ -240,21 +243,23 @@ class SwishExpScaleFunction(torch.autograd.Function):
         x.requires_grad = True
         scale.requires_grad = True
         with torch.enable_grad():
-            y = _exp_scale_swish(x, scale, ctx.speed)
+            y = _exp_scale_swish(x, scale, ctx.speed, ctx.in_scale)
             y.backward(gradient=y_grad)
-            return x.grad, scale.grad, None
+            return x.grad, scale.grad, None, None
 
 
 class SwishExpScale(torch.nn.Module):
     # combines ExpScale and a Swish (actually the ExpScale is after the Swish).
     # caution: need to specify name for speed, e.g. SwishExpScale(50, speed=4.0)
-    def __init__(self, *shape, speed: float = 1.0):
+    #
+    def __init__(self, *shape, speed: float = 1.0, in_scale: float = 1.0):
         super(SwishExpScale, self).__init__()
+        self.in_scale = in_scale
         self.scale = nn.Parameter(torch.zeros(*shape))
         self.speed = speed
 
     def forward(self, x: Tensor) -> Tensor:
-        return SwishExpScaleFunction.apply(x, self.scale, self.speed)
+        return SwishExpScaleFunction.apply(x, self.scale, self.speed, self.in_scale)
     # x = (x * torch.sigmoid(x))
     # x = (x * torch.sigmoid(x))
     # x = x * (self.scale * self.speed).exp()

diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index 2c602bbea..7b9aff71f 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -160,7 +160,7 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(d_model, dim_feedforward),
             DerivBalancer(channel_dim=-1, threshold=0.05,
                           max_factor=0.01),
-            SwishExpScale(dim_feedforward, speed=20.0),
+            SwishExpScale(dim_feedforward, speed=20.0, in_scale=0.5),
             nn.Dropout(dropout),
             nn.Linear(dim_feedforward, d_model),
         )
@@ -169,7 +169,7 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(d_model, dim_feedforward),
             DerivBalancer(channel_dim=-1, threshold=0.05,
                           max_factor=0.01),
-            SwishExpScale(dim_feedforward, speed=20.0),
+            SwishExpScale(dim_feedforward, speed=20.0, in_scale=0.5),
             nn.Dropout(dropout),
             nn.Linear(dim_feedforward, d_model),
         )

diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index b5e9e846f..190406491 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/randcombine1_expscale3_rework",
+        default="transducer_stateless/randcombine1_expscale3_rework_0.5",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
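
A minimal sketch (not part of the patch) of what the patched forward pass computes, written with plain autograd instead of the custom SwishExpScaleFunction; the name swish_exp_scale_ref and the test tensors below are hypothetical:

import torch

def swish_exp_scale_ref(x: torch.Tensor, scale: torch.Tensor,
                        speed: float = 20.0, in_scale: float = 0.5) -> torch.Tensor:
    # New in this patch: pre-scale the input before the nonlinearity.
    x = x * in_scale
    # Offset-swish, used as an approximation to double-swish.
    x = x * torch.sigmoid(x - 1.0)
    # Learned per-channel log-scale; `speed` multiplies the effective
    # learning rate of `scale`, which starts at zero (scale factor 1).
    return x * (scale * speed).exp()

x = torch.randn(4, 2048, requires_grad=True)
scale = torch.zeros(2048, requires_grad=True)  # as in nn.Parameter(torch.zeros(*shape))
y = swish_exp_scale_ref(x, scale)
y.sum().backward()  # gradients reach both x and scale

The custom torch.autograd.Function in the patch produces the same gradients; it appears to exist to save activation memory, storing only detached copies of x and scale in forward() and recomputing the forward pass under torch.enable_grad() inside backward().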