Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Implement limits on parameter values a different way.

commit ff6431ed0f
parent ce4b50d094
@@ -158,47 +158,6 @@ class ActivationScaleBalancerFunction(torch.autograd.Function):
         return x_grad - neg_delta_grad, None, None, None,


-class RandomClampFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        ctx,
-        x: Tensor,
-        min: Optional[float],
-        max: Optional[float],
-        prob: float,
-        reflect: float) -> Tensor:
-        kwargs = {}
-        if min is not None:
-            kwargs['min'] = min
-        if max is not None:
-            kwargs['max'] = max
-        x_clamped = torch.clamp(x, **kwargs)
-        mask = torch.rand_like(x) < prob
-        ans = torch.where(mask, x_clamped, x)
-        if x.requires_grad:
-            ctx.save_for_backward(ans == x)
-            ctx.reflect = reflect
-        if reflect != 0.0:
-            ans = ans * (1.0 + reflect) - (x * reflect)
-        return ans
-
-    @staticmethod
-    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
-        is_same, = ctx.saved_tensors
-        x_grad = ans_grad * is_same.to(ans_grad.dtype)
-        reflect = ctx.reflect
-        if reflect != 0.0:
-            x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
-        return x_grad, None, None, None, None
-
-def random_clamp(x: Tensor,
-                 min: Optional[float] = None,
-                 max: Optional[float] = None,
-                 prob: float = 0.5,
-                 reflect: float = 0.0):
-    return RandomClampFunction.apply(x, min, max, prob, reflect)
-
-
 def random_cast_to_half(x: Tensor,
                         min_abs: float = 5.0e-06) -> Tensor:
     """
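For contrast with the replacement introduced in the next hunk, here is a small illustrative sketch (not code from the repository) of what the removed random_clamp did when reflect is zero: values are clamped in the forward pass, so any element that actually gets clamped receives zero gradient, and the underlying value is never pulled back toward the allowed range. The comment removed from get_bypass_scale further down ("the random.random() part is to ensure we get grads if self.bypass_scale becomes out of range") is a workaround for exactly this.

    import torch

    x = torch.tensor([-2.0, 0.5, 3.0], requires_grad=True)
    x_clamped = torch.clamp(x, min=-1.0, max=1.0)
    # random_clamp used the clamped value only where a random mask fired (prob=0.5 by
    # default); prob is taken as 1.0 here so that the printed result is deterministic.
    mask = torch.rand_like(x) < 1.0
    ans = torch.where(mask, x_clamped, x)
    ans.sum().backward()
    print(x.grad)  # tensor([0., 1., 0.]) -- the clamped elements (-2.0 and 3.0) get no gradient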
@@ -757,6 +716,38 @@ def with_loss(x, y):
     return WithLoss.apply(x, y)


+class LimitParamValue(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: Tensor, min: float, max: float):
+        ctx.save_for_backward(x)
+        ctx.min = min
+        ctx.max = max
+        return x
+    @staticmethod
+    def backward(ctx, x_grad: Tensor):
+        x, = ctx.saved_tensors
+        # where x < ctx.min, ensure all grads are negative (this will tend to make
+        # x more positive).
+        x_grad *= torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0)
+        # where x > ctx.max, ensure all grads are positive (this will tend to make
+        # x more negative).
+        x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)
+        return x_grad, None, None
+
+def limit_param_value(x: Tensor,
+                      min: float, max: float,
+                      prob: float = 0.2):
+    # You apply this (typically) to an nn.Parameter during training to ensure that its
+    # elements (mostly) stay within a supplied range.  This is done by modifying the
+    # gradients in backprop.
+    # It's not necessary to do this on every batch: do it only some of the time,
+    # to save a little time.
+    if random.random() < prob:
+        return LimitParamValue.apply(x, min, max)
+    else:
+        return x
+
+
 def _no_op(x: Tensor) -> Tensor:
     if (torch.jit.is_scripting()):
         return x
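A minimal, self-contained sketch of the gradient-flipping idea in the hunk above. The class mirrors LimitParamValue for illustration (with the sign factor applied out of place rather than with *=); it is not the repository file. An element sitting below min whose gradient would push it even lower has that gradient's sign flipped, while a gradient that already points back into the range is left alone.

    import torch
    from torch import Tensor


    class LimitDemo(torch.autograd.Function):
        # Identity in the forward pass; only gradient signs are corrected in backward.
        @staticmethod
        def forward(ctx, x: Tensor, min: float, max: float):
            ctx.save_for_backward(x)
            ctx.min = min
            ctx.max = max
            return x

        @staticmethod
        def backward(ctx, x_grad: Tensor):
            x, = ctx.saved_tensors
            # Flip gradients that would push x further below min ...
            x_grad = x_grad * torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0)
            # ... or further above max.
            x_grad = x_grad * torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)
            return x_grad, None, None


    p = torch.tensor([-2.0, 0.0, 2.0], requires_grad=True)  # -2.0 is below min, 2.0 is above max
    y = LimitDemo.apply(p, -1.0, 1.0)
    y.sum().backward()  # the upstream gradient is +1.0 for every element
    print(p.grad)       # tensor([-1., 1., 1.]): only the first gradient is flipped, since +1 would
                        # push -2.0 even lower; for 2.0, a +1 gradient already moves it back down.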
@@ -34,11 +34,11 @@ from scaling import (
     Whiten,
     Identity,
     _diag,
-    random_clamp,
     penalize_abs_values_gt,
     softmax,
     ScheduledFloat,
     FloatLike,
+    limit_param_value,
 )
 from torch import Tensor, nn

@@ -435,13 +435,12 @@ class ZipformerEncoderLayer(nn.Module):
                             grad_scale=0.01)

     def get_bypass_scale(self):
-        if torch.jit.is_scripting() or not self.training or random.random() < 0.5:
-            # the random.random() part is to ensure we get grads if self.bypass_scale becomes out of range
+        if torch.jit.is_scripting() or not self.training:
             return self.bypass_scale
-
-        return self.bypass_scale.clamp(min=float(self.bypass_clamp_min),
-                                       max=float(self.bypass_clamp_max))
-
+        else:
+            return limit_param_value(self.bypass_scale,
+                                     min=float(self.bypass_clamp_min),
+                                     max=float(self.bypass_clamp_max))

     def forward(
         self,
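For orientation, a hypothetical minimal module using the same call pattern that get_bypass_scale adopts above: the parameter itself stays unconstrained, scripting and eval paths return it untouched, and during training it is routed through limit_param_value so that out-of-range gradients are redirected. The import assumes the scaling.py from this commit is importable; the class name, range and forward formula are illustrative, not taken from zipformer.py.

    import torch
    from torch import Tensor, nn

    from scaling import limit_param_value  # the helper added in the hunk above


    class BypassDemo(nn.Module):
        # Illustrative only: a learnable per-channel bypass scale kept (mostly) in [0.1, 1.0].
        def __init__(self, num_channels: int):
            super().__init__()
            self.bypass_scale = nn.Parameter(torch.full((num_channels,), 0.5))

        def get_bypass_scale(self) -> Tensor:
            if torch.jit.is_scripting() or not self.training:
                return self.bypass_scale
            # limit_param_value is the identity in the forward pass; it only redirects
            # out-of-range gradients, and only on a random ~20% of calls by default.
            return limit_param_value(self.bypass_scale, min=0.1, max=1.0)

        def forward(self, src_orig: Tensor, src: Tensor) -> Tensor:
            # Interpolate between the layer input and its output with the learned scale.
            return src_orig + (src - src_orig) * self.get_bypass_scale()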
@@ -860,10 +859,10 @@ class SimpleCombiner(torch.nn.Module):


         weight1 = self.weight1
-        if self.training and random.random() < 0.5 and self.min_weight != (0., 0.):
-            weight1 = weight1.clamp(min=self.min_weight[0],
-                                    max=1.0-self.min_weight[1])
-
+        if self.training:
+            weight1 = limit_param_value(weight1,
+                                        min=self.min_weight[0],
+                                        max=1.0-self.min_weight[1])

         src1_dim = src1.shape[-1]
         src2_dim = src2.shape[-1]