mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Add deriv-balancing code
This commit is contained in:
parent
eb3ed54202
commit
6252282fd0
@ -47,11 +47,15 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
nn.Conv2d(
|
nn.Conv2d(
|
||||||
in_channels=1, out_channels=odim, kernel_size=3, stride=2
|
in_channels=1, out_channels=odim, kernel_size=3, stride=2
|
||||||
),
|
),
|
||||||
|
DerivBalancer(channel_dim=1, threshold=0.02,
|
||||||
|
max_factor=0.02),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
ExpScale(odim, 1, 1, speed=20.0),
|
ExpScale(odim, 1, 1, speed=20.0),
|
||||||
nn.Conv2d(
|
nn.Conv2d(
|
||||||
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
|
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
|
||||||
),
|
),
|
||||||
|
DerivBalancer(channel_dim=1, threshold=0.02,
|
||||||
|
max_factor=0.02),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
ExpScale(odim, 1, 1, speed=20.0),
|
ExpScale(odim, 1, 1, speed=20.0),
|
||||||
)
|
)
|
||||||
@ -248,6 +252,68 @@ class ExpScaleSwish(torch.nn.Module):
|
|||||||
# return (x * torch.sigmoid(x)) * (self.scale * self.speed).exp()
|
# return (x * torch.sigmoid(x)) * (self.scale * self.speed).exp()
|
||||||
# return x * (self.scale * self.speed).exp()
|
# return x * (self.scale * self.speed).exp()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class DerivBalancerFunction(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x: Tensor, channel_dim: int,
|
||||||
|
threshold: 0.05, max_factor: 0.05,
|
||||||
|
epsilon: 1.0e-10) -> Tensor:
|
||||||
|
if x.requires_grad:
|
||||||
|
if channel_dim < 0:
|
||||||
|
channel_dim += x.ndim
|
||||||
|
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
|
||||||
|
proportion_positive = torch.mean((x > 0).to(x.dtype), dim=sum_dims, keepdim=True)
|
||||||
|
factor = (threshold - proportion_positive).relu() * (max_factor / threshold)
|
||||||
|
|
||||||
|
ctx.save_for_backward(factor)
|
||||||
|
ctx.epsilon = epsilon
|
||||||
|
ctx.sum_dims = sum_dims
|
||||||
|
return x
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
|
||||||
|
factor, = ctx.saved_tensors
|
||||||
|
neg_delta_grad = x_grad.abs() * factor
|
||||||
|
if ctx.epsilon != 0.0:
|
||||||
|
sum_abs_grad = torch.sum(x_grad.abs(), dim=ctx.sum_dims, keepdim=True)
|
||||||
|
deriv_is_zero = (sum_abs_grad == 0.0)
|
||||||
|
neg_delta_grad += ctx.epsilon * deriv_is_zero
|
||||||
|
|
||||||
|
return x_grad - neg_delta_grad, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class DerivBalancer(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Modifies the backpropped derivatives of a function to try to encourage, for
|
||||||
|
each channel, that it is positive at least a proportion `threshold` of the
|
||||||
|
time. It does this by multiplying negative derivative values by up to
|
||||||
|
(1+max_factor), and positive derivative values by up to (1-max_factor),
|
||||||
|
interpolated from 0 at the threshold to those extremal values when none
|
||||||
|
of the inputs are positive.
|
||||||
|
|
||||||
|
When all grads are zero for a channel, this
|
||||||
|
module sets all the input derivatives for that channel to -epsilon; the
|
||||||
|
idea is to bring completely dead neurons back to life this way.
|
||||||
|
"""
|
||||||
|
def __init__(self, channel_dim: int,
|
||||||
|
threshold: float = 0.05,
|
||||||
|
max_factor: float = 0.05,
|
||||||
|
epsilon: float = 1.0e-10):
|
||||||
|
super(DerivBalancer, self).__init__()
|
||||||
|
self.channel_dim = channel_dim
|
||||||
|
self.threshold = threshold
|
||||||
|
self.max_factor = max_factor
|
||||||
|
self.epsilon = epsilon
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold,
|
||||||
|
self.max_factor, self.epsilon)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _test_exp_scale_swish():
|
def _test_exp_scale_swish():
|
||||||
class Swish(torch.nn.Module):
|
class Swish(torch.nn.Module):
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
@ -271,5 +337,26 @@ def _test_exp_scale_swish():
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _test_deriv_balancer():
|
||||||
|
channel_dim = 0
|
||||||
|
probs = torch.arange(0, 1, 0.01)
|
||||||
|
N = 500
|
||||||
|
x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
|
||||||
|
x = x.detach()
|
||||||
|
x.requires_grad = True
|
||||||
|
m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, epsilon=1.0e-10)
|
||||||
|
|
||||||
|
y_grad = torch.sign(torch.randn(probs.numel(), N))
|
||||||
|
y_grad[-1,:] = 0
|
||||||
|
|
||||||
|
y = m(x)
|
||||||
|
y.backward(gradient=y_grad)
|
||||||
|
print("x = ", x)
|
||||||
|
print("y grad = ", y_grad)
|
||||||
|
print("x grad = ", x.grad)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
_test_deriv_balancer()
|
||||||
_test_exp_scale_swish()
|
_test_exp_scale_swish()
|
||||||
|
@ -19,7 +19,7 @@ import copy
|
|||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Optional, Tuple, Sequence
|
from typing import Optional, Tuple, Sequence
|
||||||
from subsampling import PeLU, ExpScale, ExpScaleSwish
|
from subsampling import PeLU, ExpScale, ExpScaleSwish, DerivBalancer
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
@ -156,6 +156,8 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
|
|
||||||
self.feed_forward = nn.Sequential(
|
self.feed_forward = nn.Sequential(
|
||||||
nn.Linear(d_model, dim_feedforward),
|
nn.Linear(d_model, dim_feedforward),
|
||||||
|
DerivBalancer(channel_dim=-1, threshold=0.02,
|
||||||
|
max_factor=0.02),
|
||||||
ExpScaleSwish(dim_feedforward, speed=20.0),
|
ExpScaleSwish(dim_feedforward, speed=20.0),
|
||||||
nn.Dropout(dropout),
|
nn.Dropout(dropout),
|
||||||
nn.Linear(dim_feedforward, d_model),
|
nn.Linear(dim_feedforward, d_model),
|
||||||
@ -163,6 +165,8 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
|
|
||||||
self.feed_forward_macaron = nn.Sequential(
|
self.feed_forward_macaron = nn.Sequential(
|
||||||
nn.Linear(d_model, dim_feedforward),
|
nn.Linear(d_model, dim_feedforward),
|
||||||
|
DerivBalancer(channel_dim=-1, threshold=0.02,
|
||||||
|
max_factor=0.02),
|
||||||
ExpScaleSwish(dim_feedforward, speed=20.0),
|
ExpScaleSwish(dim_feedforward, speed=20.0),
|
||||||
nn.Dropout(dropout),
|
nn.Dropout(dropout),
|
||||||
nn.Linear(dim_feedforward, d_model),
|
nn.Linear(dim_feedforward, d_model),
|
||||||
|
@ -110,7 +110,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5",
|
default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu",
|
||||||
help="""The experiment dir.
|
help="""The experiment dir.
|
||||||
It specifies the directory where all training related
|
It specifies the directory where all training related
|
||||||
files, e.g., checkpoints, log, etc, are saved
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
Loading…
x
Reference in New Issue
Block a user