Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Add reflect=0.1 to invocations of random_clamp()
This commit is contained in:
parent 8e15d4312a
commit f4442de1c4
@@ -1116,7 +1116,8 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output_weights = random_clamp(attn_output_weights,
                                            min=-attn_weights_max,
                                            max=attn_weights_max,
-                                           prob=0.5)
+                                           prob=0.5,
+                                           reflect=0.1)

         # attn_output_weights: (batch, head, time1, time2)
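Not part of the commit, just a reading aid: a toy illustration of the arithmetic the new reflect argument adds, written with plain torch.clamp and a made-up attn_weights_max of 25.0 (the real value is not shown in this diff). When an element is clamped, the result is nudged a bit further past the clamp point, away from the original value: clamped + reflect * (clamped - x).

import torch

attn_weights_max = 25.0  # illustrative value, not taken from the commit
x = torch.tensor([[-30.0, 0.0, 30.0]])
clamped = torch.clamp(x, min=-attn_weights_max, max=attn_weights_max)
reflect = 0.1
# Same expression as in the final hunk: ans * (1.0 + reflect) - (x * reflect)
reflected = clamped * (1.0 + reflect) - (x * reflect)
print(clamped)    # tensor([[-25.,   0.,  25.]])
print(reflected)  # tensor([[-24.5000,   0.0000,  24.5000]])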
@@ -143,8 +143,10 @@ class Transducer(nn.Module):
         am = self.simple_am_proj(encoder_out)

         if self.training:
-            lm = random_clamp(lm, min=-8.0, max=2.0, prob=0.5)
-            am = random_clamp(am, min=-5.0, max=5.0, prob=0.5)
+            lm = random_clamp(lm, min=-8.0, max=2.0, prob=0.5,
+                              reflect=0.1)
+            am = random_clamp(am, min=-5.0, max=5.0, prob=0.5,
+                              reflect=0.1)

         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
@@ -182,7 +184,8 @@ class Transducer(nn.Module):
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)

         if self.training:
-            logits = random_clamp(logits, -8.0, 2.0, prob=0.5)
+            logits = random_clamp(logits, -8.0, 2.0, prob=0.5,
+                                  reflect=0.1)

         with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
@@ -165,24 +165,34 @@ class RandomClampFunction(torch.autograd.Function):
                 x: Tensor,
                 min: Optional[float],
                 max: Optional[float],
-                prob: float) -> Tensor:
+                prob: float,
+                reflect: float) -> Tensor:
         x_clamped = torch.clamp(x, min=min, max=max)
         mask = torch.rand_like(x) < prob
         ans = torch.where(mask, x_clamped, x)
         if x.requires_grad:
             ctx.save_for_backward(ans == x)
+            ctx.reflect = reflect
+        if reflect != 0.0:
+            ans = ans * (1.0 + reflect) - (x * reflect)
+
         return ans

     @staticmethod
-    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None]:
+    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
         is_same, = ctx.saved_tensors
-        return ans_grad * is_same.to(ans_grad.dtype), None, None, None
+        x_grad = ans_grad * is_same.to(ans_grad.dtype)
+        reflect = ctx.reflect
+        if reflect != 0.0:
+            x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
+        return x_grad, None, None, None, None

 def random_clamp(x: Tensor,
                  min: Optional[float] = None,
                  max: Optional[float] = None,
-                 prob: float = 0.5):
-    return RandomClampFunction.apply(x, min, max, prob)
+                 prob: float = 0.5,
+                 reflect: float = 0.0):
+    return RandomClampFunction.apply(x, min, max, prob, reflect)
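Not part of the commit: a minimal standalone sketch of the RandomClampFunction from the hunk above, with the reflect handling as added here, plus a small demo of its effect on gradients. With reflect = r, an element that gets clamped no longer receives a zero gradient; it receives -r times the incoming gradient (unclamped elements still pass the gradient through unchanged), which is the exact derivative of the reflected forward pass. The demo values and the prob=1.0 setting are illustrative only.

from typing import Optional, Tuple

import torch
from torch import Tensor


class RandomClampFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx,
                x: Tensor,
                min: Optional[float],
                max: Optional[float],
                prob: float,
                reflect: float) -> Tensor:
        x_clamped = torch.clamp(x, min=min, max=max)
        mask = torch.rand_like(x) < prob       # clamp each element with probability `prob`
        ans = torch.where(mask, x_clamped, x)
        if x.requires_grad:
            ctx.save_for_backward(ans == x)    # True where the element was not clamped
            ctx.reflect = reflect
        if reflect != 0.0:
            # Push clamped values slightly past the clamp point, away from x:
            # ans + reflect * (ans - x)
            ans = ans * (1.0 + reflect) - (x * reflect)
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
        is_same, = ctx.saved_tensors
        x_grad = ans_grad * is_same.to(ans_grad.dtype)
        reflect = ctx.reflect
        if reflect != 0.0:
            # Exact gradient of the reflected forward: unclamped elements keep
            # ans_grad, clamped elements get -reflect * ans_grad instead of 0.
            x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
        return x_grad, None, None, None, None


def random_clamp(x: Tensor,
                 min: Optional[float] = None,
                 max: Optional[float] = None,
                 prob: float = 0.5,
                 reflect: float = 0.0):
    return RandomClampFunction.apply(x, min, max, prob, reflect)


if __name__ == "__main__":
    x = torch.tensor([-10.0, 0.0, 10.0], requires_grad=True)
    # prob=1.0 so every out-of-range element is clamped in this demo.
    y = random_clamp(x, min=-8.0, max=2.0, prob=1.0, reflect=0.1)
    y.sum().backward()
    print(y)       # tensor([-7.8000,  0.0000,  1.2000], grad_fn=...)
    print(x.grad)  # tensor([-0.1000,  1.0000, -0.1000])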