mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)

commit ff6431ed0f
parent ce4b50d094

    Implement limits on parameter values a different way.
@@ -158,47 +158,6 @@ class ActivationScaleBalancerFunction(torch.autograd.Function):
         return x_grad - neg_delta_grad, None, None, None,
 
 
-class RandomClampFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(
-            ctx,
-            x: Tensor,
-            min: Optional[float],
-            max: Optional[float],
-            prob: float,
-            reflect: float) -> Tensor:
-        kwargs = {}
-        if min is not None:
-            kwargs['min'] = min
-        if max is not None:
-            kwargs['max'] = max
-        x_clamped = torch.clamp(x, **kwargs)
-        mask = torch.rand_like(x) < prob
-        ans = torch.where(mask, x_clamped, x)
-        if x.requires_grad:
-            ctx.save_for_backward(ans == x)
-            ctx.reflect = reflect
-        if reflect != 0.0:
-            ans = ans * (1.0 + reflect) - (x * reflect)
-        return ans
-
-    @staticmethod
-    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
-        is_same, = ctx.saved_tensors
-        x_grad = ans_grad * is_same.to(ans_grad.dtype)
-        reflect = ctx.reflect
-        if reflect != 0.0:
-            x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
-        return x_grad, None, None, None, None
-
-
-def random_clamp(x: Tensor,
-                 min: Optional[float] = None,
-                 max: Optional[float] = None,
-                 prob: float = 0.5,
-                 reflect: float = 0.0):
-    return RandomClampFunction.apply(x, min, max, prob, reflect)
-
-
 def random_cast_to_half(x: Tensor,
                         min_abs: float = 5.0e-06) -> Tensor:
     """
@@ -757,6 +716,38 @@ def with_loss(x, y):
     return WithLoss.apply(x, y)
 
 
+class LimitParamValue(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: Tensor, min: float, max: float):
+        ctx.save_for_backward(x)
+        ctx.min = min
+        ctx.max = max
+        return x
+
+    @staticmethod
+    def backward(ctx, x_grad: Tensor):
+        x, = ctx.saved_tensors
+        # where x < ctx.min, ensure all grads are negative (this will tend to make
+        # x more positive).
+        x_grad *= torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0)
+        # where x > ctx.max, ensure all grads are positive (this will tend to make
+        # x more negative).
+        x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)
+        return x_grad, None, None
+
+
+def limit_param_value(x: Tensor,
+                      min: float, max: float,
+                      prob: float = 0.2):
+    # You apply this to (typically) an nn.Parameter during training to ensure that its
+    # elements (mostly) stay within a supplied range.  This is done by modifying the
+    # gradients in backprop.
+    # It's not necessary to do this on every batch: do it only some of the time,
+    # to save a little time.
+    if random.random() < prob:
+        return LimitParamValue.apply(x, min, max)
+    else:
+        return x
+
+
 def _no_op(x: Tensor) -> Tensor:
     if (torch.jit.is_scripting()):
         return x
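Note (not part of the commit): the backward rule added above is easiest to see on made-up numbers. The two torch.where masks flip any gradient component that would push an out-of-range element further out of range, so a plain SGD step (x -= lr * grad) pulls it back toward [min, max]. A minimal standalone sketch, assuming only PyTorch, with invented values:

import torch

x = torch.tensor([-2.0, 0.5, 3.0])       # below min, in range, above max
x_grad = torch.tensor([1.0, 1.0, -1.0])  # gradients some loss happened to produce
min_val, max_val = 0.0, 1.0              # allowed range

# below min: force the grad to be negative, so an SGD step increases x
x_grad = x_grad * torch.where(torch.logical_and(x_grad > 0, x < min_val), -1.0, 1.0)
# above max: force the grad to be positive, so an SGD step decreases x
x_grad = x_grad * torch.where(torch.logical_and(x_grad < 0, x > max_val), -1.0, 1.0)

print(x_grad)  # tensor([-1., 1., 1.]): both out-of-range elements are now pushed back into range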
@@ -34,11 +34,11 @@ from scaling import (
     Whiten,
     Identity,
     _diag,
-    random_clamp,
     penalize_abs_values_gt,
     softmax,
     ScheduledFloat,
     FloatLike,
+    limit_param_value,
 )
 from torch import Tensor, nn
 
@@ -435,13 +435,12 @@ class ZipformerEncoderLayer(nn.Module):
             grad_scale=0.01)
 
     def get_bypass_scale(self):
-        if torch.jit.is_scripting() or not self.training or random.random() < 0.5:
-            # the random.random() part is to ensure we get grads if self.bypass_scale becomes out of range
+        if torch.jit.is_scripting() or not self.training:
             return self.bypass_scale
-        return self.bypass_scale.clamp(min=float(self.bypass_clamp_min),
-                                       max=float(self.bypass_clamp_max))
+        else:
+            return limit_param_value(self.bypass_scale,
+                                     min=float(self.bypass_clamp_min),
+                                     max=float(self.bypass_clamp_max))
 
     def forward(
         self,
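Note (not part of the commit): the comment removed above explains the motivation for this change. Tensor.clamp() passes zero gradient to elements that lie outside the range, so a parameter that drifted out of range and was clamped on every batch would never receive a gradient pulling it back; that is why the old code skipped the clamp half the time. limit_param_value leaves the forward value untouched and only redirects the gradient, so no such guard is needed. A quick standalone check of the clamp behaviour, assuming only PyTorch and an invented value:

import torch

p = torch.nn.Parameter(torch.tensor(1.5))  # pretend the allowed max is 1.0
p.clamp(max=1.0).backward()
print(p.grad)  # tensor(0.) -- the clamped element receives no gradient at all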
@@ -860,10 +859,10 @@ class SimpleCombiner(torch.nn.Module):
 
 
         weight1 = self.weight1
-        if self.training and random.random() < 0.5 and self.min_weight != (0., 0.):
-            weight1 = weight1.clamp(min=self.min_weight[0],
-                                    max=1.0-self.min_weight[1])
+        if self.training:
+            weight1 = limit_param_value(weight1,
+                                        min=self.min_weight[0],
+                                        max=1.0-self.min_weight[1])
 
         src1_dim = src1.shape[-1]
         src2_dim = src2.shape[-1]
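Note (not part of the commit): same pattern as in get_bypass_scale above. The old random.random() < 0.5 and self.min_weight != (0., 0.) guard is dropped because limit_param_value already applies its gradient correction only with probability prob (0.2 by default) and, unlike clamp, never alters the value used in the forward computation. A small usage sketch; the import is assumed to resolve the same way as the `from scaling import (...)` hunk above:

import torch
from scaling import limit_param_value  # assumes scaling.py is importable, as in the recipe directory

w = torch.nn.Parameter(torch.tensor([0.02, 0.98]))
out = limit_param_value(w, min=0.05, max=0.95)
# the forward value is identical either way; only gradients may be flipped in
# backward, and only on a random subset of calls
assert torch.equal(out.detach(), w.detach())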