Merge branch 'scaled_adam_exp146' into scaled_adam_exp149

Daniel Povey 2022-10-19 19:16:27 +08:00
commit ef5a27388f
3 changed files with 30 additions and 9 deletions

View File

@@ -870,8 +870,6 @@ class RelPositionMultiheadAttention(nn.Module):
         self.copy_pos_query = Identity()
         self.copy_query = Identity()
-        self.in_balancer = ActivationBalancer(3 * attention_dim,
-                                              channel_dim=-1, max_abs=5.0)
         self.out_proj = ScaledLinear(
             attention_dim // 2, embed_dim, bias=True, initial_scale=0.05
         )
@@ -931,7 +929,7 @@ class RelPositionMultiheadAttention(nn.Module):
            and S is the sequence length.
         """
         x, weights = self.multi_head_attention_forward(
-            self.in_balancer(self.in_proj(x)),
+            self.in_proj(x),
             self.linear_pos(pos_emb),
             self.attention_dim,
             self.num_heads,
@@ -1117,7 +1115,8 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output_weights = random_clamp(attn_output_weights,
                                            min=-attn_weights_max,
                                            max=attn_weights_max,
-                                           prob=0.5)
+                                           prob=0.5,
+                                           reflect=0.1)
         # attn_output_weights: (batch, head, time1, time2)
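
Not part of the commit: a minimal sketch of what the new reflect argument does to a clamped value, using plain tensor ops instead of the RandomClampFunction shown further down in scaling.py. The clamp limits and reflect value below are chosen only for illustration. For an element that overshoots the range, ans * (1.0 + reflect) - x * reflect moves the output past the clamp boundary, back into the interval, by reflect times the overshoot; elements with ans == x come out unchanged.

    import torch

    # Illustration only; the limits mimic the style of the hunk above, not its exact values.
    x = torch.tensor([3.0, 12.0])               # second element overshoots max
    ans = torch.clamp(x, min=-5.0, max=5.0)     # -> [3.0, 5.0]
    reflect = 0.1
    out = ans * (1.0 + reflect) - x * reflect   # -> [3.0, 4.3], i.e. 5.0 - 0.1 * (12.0 - 5.0)
    print(out)                                  # unclamped values are unchanged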

View File

@@ -19,10 +19,12 @@ import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
+from scaling import random_clamp

 from icefall.utils import add_sos


 class Transducer(nn.Module):
     """It implements https://arxiv.org/pdf/1211.3711.pdf
     "Sequence Transduction with Recurrent Neural Networks"
@@ -140,6 +142,12 @@ class Transducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)
+        if self.training:
+            lm = random_clamp(lm, min=-8.0, max=2.0, prob=0.5,
+                              reflect=0.1)
+            am = random_clamp(am, min=-5.0, max=5.0, prob=0.5,
+                              reflect=0.1)
+
         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
@@ -175,6 +183,10 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)
+        if self.training:
+            logits = random_clamp(logits, -8.0, 2.0, prob=0.5,
+                                  reflect=0.1)
+
         with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
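
Not part of the commit: the model.py hunks all use the same pattern, apply random_clamp only when self.training is true, so evaluation and decoding see the raw projections and logits. A minimal sketch of that pattern, assuming scaling.random_clamp with the signature added below; the DemoProj module is hypothetical and only illustrates the training-mode guard.

    import torch
    import torch.nn as nn
    from scaling import random_clamp  # assumed: (x, min=None, max=None, prob=0.5, reflect=0.0)


    class DemoProj(nn.Module):
        """Hypothetical projection that clamps its output only in training mode."""

        def __init__(self, in_dim: int, out_dim: int):
            super().__init__()
            self.proj = nn.Linear(in_dim, out_dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            y = self.proj(x)
            if self.training:
                # Same limits as the simple_lm_proj branch above; a no-op after .eval().
                y = random_clamp(y, min=-8.0, max=2.0, prob=0.5, reflect=0.1)
            return y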

View File

@@ -165,24 +165,34 @@ class RandomClampFunction(torch.autograd.Function):
                 x: Tensor,
                 min: Optional[float],
                 max: Optional[float],
-                prob: float) -> Tensor:
+                prob: float,
+                reflect: float) -> Tensor:
         x_clamped = torch.clamp(x, min=min, max=max)
         mask = torch.rand_like(x) < prob
         ans = torch.where(mask, x_clamped, x)
         if x.requires_grad:
             ctx.save_for_backward(ans == x)
+            ctx.reflect = reflect
+        if reflect != 0.0:
+            ans = ans * (1.0 + reflect) - (x * reflect)
         return ans

     @staticmethod
-    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None]:
+    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
         is_same, = ctx.saved_tensors
-        return ans_grad * is_same.to(ans_grad.dtype), None, None, None
+        x_grad = ans_grad * is_same.to(ans_grad.dtype)
+        reflect = ctx.reflect
+        if reflect != 0.0:
+            x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
+        return x_grad, None, None, None, None


 def random_clamp(x: Tensor,
                  min: Optional[float] = None,
                  max: Optional[float] = None,
-                 prob: float = 0.5):
-    return RandomClampFunction.apply(x, min, max, prob)
+                 prob: float = 0.5,
+                 reflect: float = 0.0):
+    return RandomClampFunction.apply(x, min, max, prob, reflect)


 def random_cast_to_half(x: Tensor,
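
Not part of the commit: a small usage check of the new reflect path, assuming the scaling module above is importable, prob=1.0 so the clamp always fires, and a backward that returns the reflected x_grad as reconstructed above. Inside the range the (1 + reflect) and -reflect terms cancel and the gradient passes through as 1.0; for a clamped element is_same is False, so only the -reflect * ans_grad term survives and the gradient nudges the value back toward the allowed interval.

    import torch
    from scaling import random_clamp

    x = torch.tensor([0.0, 12.0], requires_grad=True)
    y = random_clamp(x, min=-5.0, max=5.0, prob=1.0, reflect=0.1)
    y.sum().backward()

    print(y)       # tensor([0.0000, 4.3000], grad_fn=...): 12.0 is clamped, then reflected
    print(x.grad)  # tensor([ 1.0000, -0.1000])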