Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Regularize how we apply the min and max to the eps of BasicNorm

parent a0507a83a5
commit bf37c7ca85
@@ -360,7 +360,11 @@ class ConformerEncoderLayer(nn.Module):
         delta = src - src_orig
         bypass_scale = self.bypass_scale
-        if random.random() > 0.1:
+        if self.training and random.random() < 0.25:
+            # with probability 0.25, in training mode, clamp bypass_scale to
+            # [0.1, 1.0]; this will encourage it to learn parameters within
+            # this range by making parameters that are outside that range
+            # noisy.
             bypass_scale = bypass_scale.clamp(min=0.1, max=1.0)
         src = src_orig + delta * self.bypass_scale
 
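With this change the clamp is applied on roughly a quarter of training batches, and only in training mode, instead of on 90% of all forward calls. A rough, self-contained sketch of the trick in isolation follows; ScaledBypass and its names are hypothetical, not the icefall ConformerEncoderLayer, and this sketch multiplies by the possibly-clamped local value.

import random
import torch
import torch.nn as nn


class ScaledBypass(nn.Module):
    """Toy residual-bypass module illustrating the probabilistic clamp.

    bypass_scale is learnable; during training it is clamped to
    [min_scale, max_scale] on about 25% of batches, so values outside
    that range become noisy (they alternate between the raw and clamped
    value), while the remaining batches still provide gradients that can
    pull the parameter back inside the range.  Assumes channels-last input.
    """

    def __init__(self, num_channels: int, min_scale: float = 0.1, max_scale: float = 1.0):
        super().__init__()
        self.bypass_scale = nn.Parameter(torch.full((num_channels,), 0.5))
        self.min_scale = min_scale
        self.max_scale = max_scale

    def forward(self, src_orig: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
        # src_orig: the input to the block; src: the block's output.
        delta = src - src_orig
        bypass_scale = self.bypass_scale
        if self.training and random.random() < 0.25:
            bypass_scale = bypass_scale.clamp(min=self.min_scale, max=self.max_scale)
        return src_orig + delta * bypass_scale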
@@ -159,7 +159,7 @@ class ScaledAdam(BatchedOptimizer):
         eps=1.0e-08,
         param_min_rms=1.0e-05,
         param_max_rms=3.0,
-        scalar_max=5.0,
+        scalar_max=10.0,
         size_update_period=4,
         clipping_update_period=100,
     ):
@@ -348,6 +348,8 @@ class BasicNorm(torch.nn.Module):
           to indicate the connection with conventional LayerNorm.
        learn_eps: if true, we learn epsilon; if false, we keep it
           at the initial value.
+       eps_min: float
+       eps_max: float
     """
 
     def __init__(
@@ -356,6 +358,8 @@ class BasicNorm(torch.nn.Module):
         channel_dim: int = -1,  # CAUTION: see documentation.
         eps: float = 0.25,
         learn_eps: bool = True,
+        eps_min: float = -3.0,
+        eps_max: float = 3.0,
     ) -> None:
         super(BasicNorm, self).__init__()
         self.num_channels = num_channels
@@ -364,9 +368,21 @@ class BasicNorm(torch.nn.Module):
             self.eps = nn.Parameter(torch.tensor(eps).log().detach())
         else:
             self.register_buffer("eps", torch.tensor(eps).log().detach())
+        self.eps_min = eps_min
+        self.eps_max = eps_max
 
     def forward(self, x: Tensor) -> Tensor:
         assert x.shape[self.channel_dim] == self.num_channels
+        eps = self.eps
+        if self.training and random.random() < 0.25:
+            # with probability 0.25, in training mode, clamp eps between the min
+            # and max; this will encourage it to learn parameters within the
+            # allowed range by making parameters that are outside the allowed
+            # range noisy.
+            # Clamping only some of the time means the un-clamped batches still give
+            # gradients to allow the parameter to get back into the allowed
+            # region if it happens to exit it.
+            eps = eps.clamp(min=self.eps_min, max=self.eps_max)
         scales = (
             torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
             + self.eps.exp()
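In BasicNorm the eps parameter is stored as log(eps) (note torch.tensor(eps).log() above), so the defaults eps_min=-3.0 and eps_max=3.0 bound the log value; when the clamp fires, eps itself stays within roughly [exp(-3), exp(3)], i.e. about [0.05, 20]. Below is a minimal self-contained sketch of the idea with simplified names, not the icefall class itself; it assumes the clamped value is the one used in the scales expression (the hunk above still reads self.eps.exp() there) and omits the learn_eps option.

import random
import torch
import torch.nn as nn
from torch import Tensor


class ClampedBasicNorm(nn.Module):
    """Sketch of BasicNorm-style normalization with a range-regularized eps.

    eps is stored as log(eps); eps_min and eps_max therefore bound the log
    value, so the defaults keep eps itself roughly within [0.05, 20] on the
    batches where the clamp is applied.
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int = -1,
        eps: float = 0.25,
        eps_min: float = -3.0,
        eps_max: float = 3.0,
    ) -> None:
        super().__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.eps = nn.Parameter(torch.tensor(eps).log().detach())
        self.eps_min = eps_min
        self.eps_max = eps_max

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.channel_dim] == self.num_channels
        eps = self.eps
        if self.training and random.random() < 0.25:
            # On about 25% of training batches, clamp log(eps) to the allowed
            # range, making out-of-range values noisy; the remaining batches
            # still provide gradients that can pull eps back into the range.
            eps = eps.clamp(min=self.eps_min, max=self.eps_max)
        scales = (
            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
        ) ** -0.5
        return x * scales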