Implement limits on parameter values a different way.

This commit is contained in:
Daniel Povey 2022-11-14 16:02:38 +08:00
parent ce4b50d094
commit ff6431ed0f
2 changed files with 42 additions and 52 deletions

View File

@ -158,47 +158,6 @@ class ActivationScaleBalancerFunction(torch.autograd.Function):
return x_grad - neg_delta_grad, None, None, None, return x_grad - neg_delta_grad, None, None, None,
class RandomClampFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x: Tensor,
min: Optional[float],
max: Optional[float],
prob: float,
reflect: float) -> Tensor:
kwargs = {}
if min is not None:
kwargs['min'] = min
if max is not None:
kwargs['max'] = max
x_clamped = torch.clamp(x, **kwargs)
mask = torch.rand_like(x) < prob
ans = torch.where(mask, x_clamped, x)
if x.requires_grad:
ctx.save_for_backward(ans == x)
ctx.reflect = reflect
if reflect != 0.0:
ans = ans * (1.0 + reflect) - (x * reflect)
return ans
@staticmethod
def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
is_same, = ctx.saved_tensors
x_grad = ans_grad * is_same.to(ans_grad.dtype)
reflect = ctx.reflect
if reflect != 0.0:
x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
return x_grad, None, None, None, None
def random_clamp(x: Tensor,
min: Optional[float] = None,
max: Optional[float] = None,
prob: float = 0.5,
reflect: float = 0.0):
return RandomClampFunction.apply(x, min, max, prob, reflect)
def random_cast_to_half(x: Tensor, def random_cast_to_half(x: Tensor,
min_abs: float = 5.0e-06) -> Tensor: min_abs: float = 5.0e-06) -> Tensor:
""" """
@ -757,6 +716,38 @@ def with_loss(x, y):
return WithLoss.apply(x, y) return WithLoss.apply(x, y)
class LimitParamValue(torch.autograd.Function):
@staticmethod
def forward(ctx, x: Tensor, min: float, max: float):
ctx.save_for_backward(x)
ctx.min = min
ctx.max = max
return x
@staticmethod
def backward(ctx, x_grad: Tensor):
x, = ctx.saved_tensors
# where x < ctx.min, ensure all grads are negative (this will tend to make
# x more positive).
x_grad *= torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0)
# where x > ctx.max, ensure all grads are positive (this will tend to make
# x more negative).
x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)
return x_grad, None, None
def limit_param_value(x: Tensor,
min: float, max: float,
prob: float = 0.2):
# You apply this to (typically) an nn.Parameter during training to ensure that its
# (elements mostly) stays within a supplied range. This is done by modifying the
# gradients in backprop.
# It's not necessary to do this on every batch: do it only some of the time,
# to save a little time.
if random.random() < prob:
return LimitParamValue.apply(x, min, max)
else:
return x
def _no_op(x: Tensor) -> Tensor: def _no_op(x: Tensor) -> Tensor:
if (torch.jit.is_scripting()): if (torch.jit.is_scripting()):
return x return x

View File

@ -34,11 +34,11 @@ from scaling import (
Whiten, Whiten,
Identity, Identity,
_diag, _diag,
random_clamp,
penalize_abs_values_gt, penalize_abs_values_gt,
softmax, softmax,
ScheduledFloat, ScheduledFloat,
FloatLike, FloatLike,
limit_param_value,
) )
from torch import Tensor, nn from torch import Tensor, nn
@ -435,13 +435,12 @@ class ZipformerEncoderLayer(nn.Module):
grad_scale=0.01) grad_scale=0.01)
def get_bypass_scale(self): def get_bypass_scale(self):
if torch.jit.is_scripting() or not self.training or random.random() < 0.5: if torch.jit.is_scripting() or not self.training:
# the random.random() part is to ensure we get grads if self.bypass_scale becomes out of range
return self.bypass_scale return self.bypass_scale
else:
return self.bypass_scale.clamp(min=float(self.bypass_clamp_min), return limit_param_value(self.bypass_scale,
max=float(self.bypass_clamp_max)) min=float(self.bypass_clamp_min),
max=float(self.bypass_clamp_max))
def forward( def forward(
self, self,
@ -860,10 +859,10 @@ class SimpleCombiner(torch.nn.Module):
weight1 = self.weight1 weight1 = self.weight1
if self.training and random.random() < 0.5 and self.min_weight != (0., 0.): if self.training:
weight1 = weight1.clamp(min=self.min_weight[0], weight1 = limit_param_value(weight1,
max=1.0-self.min_weight[1]) min=self.min_weight[0],
max=1.0-self.min_weight[1])
src1_dim = src1.shape[-1] src1_dim = src1.shape[-1]
src2_dim = src2.shape[-1] src2_dim = src2.shape[-1]