Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Merge branch 'scaled_adam_exp146' into scaled_adam_exp149

Commit ef5a27388f
@@ -870,8 +870,6 @@ class RelPositionMultiheadAttention(nn.Module):
         self.copy_pos_query = Identity()
         self.copy_query = Identity()
 
-        self.in_balancer = ActivationBalancer(3 * attention_dim,
-                                              channel_dim=-1, max_abs=5.0)
         self.out_proj = ScaledLinear(
             attention_dim // 2, embed_dim, bias=True, initial_scale=0.05
         )
@@ -931,7 +929,7 @@ class RelPositionMultiheadAttention(nn.Module):
             and S is the sequence length.
         """
         x, weights = self.multi_head_attention_forward(
-            self.in_balancer(self.in_proj(x)),
+            self.in_proj(x),
             self.linear_pos(pos_emb),
             self.attention_dim,
             self.num_heads,
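These two hunks remove the ActivationBalancer that previously wrapped the output of self.in_proj, so the joint query/key/value projection now feeds multi_head_attention_forward directly. For orientation only, here is a toy sketch of the kind of module being dropped; it is not icefall's ActivationBalancer (whose constraints and scheduling are more involved), just an identity-in-forward module whose backward nudges over-large activations toward zero:

import torch
from torch import Tensor


class ToyAbsBalancer(torch.autograd.Function):
    """Toy stand-in for the removed balancer (NOT icefall's ActivationBalancer):
    identity in the forward pass; the backward adds a small term that pushes
    activations whose |value| exceeds `max_abs` back toward zero."""

    @staticmethod
    def forward(ctx, x: Tensor, max_abs: float, scale: float) -> Tensor:
        ctx.save_for_backward(x)
        ctx.max_abs, ctx.scale = max_abs, scale
        return x

    @staticmethod
    def backward(ctx, grad: Tensor):
        (x,) = ctx.saved_tensors
        # +scale where x > max_abs, -scale where x < -max_abs, 0 otherwise;
        # gradient descent then shrinks over-large activations.
        penalty = torch.where(x.abs() > ctx.max_abs, x.sign(), torch.zeros_like(x))
        return grad + ctx.scale * penalty, None, None


# e.g. applied to a q/k/v projection output of size 3 * attention_dim:
qkv = torch.randn(10, 2, 3 * 256, requires_grad=True)
balanced = ToyAbsBalancer.apply(qkv, 5.0, 0.01)  # same values, modified gradients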
@@ -1117,7 +1115,8 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output_weights = random_clamp(attn_output_weights,
                                            min=-attn_weights_max,
                                            max=attn_weights_max,
-                                           prob=0.5)
+                                           prob=0.5,
+                                           reflect=0.1)
 
         # attn_output_weights: (batch, head, time1, time2)
 
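Here the random clamping of the raw attention scores to ±attn_weights_max gains a reflect=0.1 argument, matching the forward formula added in the scaling.py hunk further down. A minimal stand-alone demo of the forward behaviour, written with plain torch ops rather than icefall's random_clamp (the values are made up):

import torch

x = torch.tensor([-12.0, -3.0, 0.5, 7.0, 20.0])
min_val, max_val, prob, reflect = -10.0, 10.0, 0.5, 0.1

x_clamped = torch.clamp(x, min=min_val, max=max_val)
mask = torch.rand_like(x) < prob          # clamp each element with probability `prob`
ans = torch.where(mask, x_clamped, x)
if reflect != 0.0:
    # Reflect the result slightly about the original value: elements that were
    # clamped land a bit inside the boundary instead of exactly on it, while
    # un-clamped elements come out unchanged.
    ans = ans * (1.0 + reflect) - x * reflect
print(ans)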
@@ -19,10 +19,12 @@ import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
+from scaling import random_clamp
+
 
 from icefall.utils import add_sos
 
 
 class Transducer(nn.Module):
     """It implements https://arxiv.org/pdf/1211.3711.pdf
     "Sequence Transduction with Recurrent Neural Networks"
@@ -140,6 +142,12 @@ class Transducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)
 
+        if self.training:
+            lm = random_clamp(lm, min=-8.0, max=2.0, prob=0.5,
+                              reflect=0.1)
+            am = random_clamp(am, min=-5.0, max=5.0, prob=0.5,
+                              reflect=0.1)
+
         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
@@ -175,6 +183,10 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)
 
+        if self.training:
+            logits = random_clamp(logits, -8.0, 2.0, prob=0.5,
+                                  reflect=0.1)
+
         with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
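Both model.py hunks follow the same pattern: only when self.training is set, the tensors feeding k2.rnnt_loss_smoothed (lm, am) and k2.rnnt_loss_pruned (logits) are randomly clamped to fixed ranges (-8..2 for lm and logits, ±5 for am) with reflect=0.1. A hedged sketch of that gating as a plain-torch helper; the helper name, signature, and example shapes are illustrative, not the repository's API:

import torch
from torch import Tensor


def train_only_random_clamp(x: Tensor, min: float, max: float,
                            prob: float = 0.5, training: bool = True) -> Tensor:
    """Illustrative helper: clamp a random subset of elements to [min, max],
    but only while training (no custom backward or reflect logic here)."""
    if not training:
        return x
    mask = torch.rand_like(x) < prob
    return torch.where(mask, torch.clamp(x, min=min, max=max), x)


# Ranges mirroring the diff: asymmetric for lm/logits, symmetric for am.
lm = train_only_random_clamp(torch.randn(2, 10, 500) * 4.0, min=-8.0, max=2.0)
am = train_only_random_clamp(torch.randn(2, 80, 500) * 4.0, min=-5.0, max=5.0)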
@@ -165,24 +165,34 @@ class RandomClampFunction(torch.autograd.Function):
                 x: Tensor,
                 min: Optional[float],
                 max: Optional[float],
-                prob: float) -> Tensor:
+                prob: float,
+                reflect: float) -> Tensor:
         x_clamped = torch.clamp(x, min=min, max=max)
         mask = torch.rand_like(x) < prob
         ans = torch.where(mask, x_clamped, x)
         if x.requires_grad:
             ctx.save_for_backward(ans == x)
+            ctx.reflect = reflect
+        if reflect != 0.0:
+            ans = ans * (1.0 + reflect) - (x * reflect)
+
         return ans
 
     @staticmethod
-    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None]:
+    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
         is_same, = ctx.saved_tensors
-        return ans_grad * is_same.to(ans_grad.dtype), None, None, None
+        x_grad = ans_grad * is_same.to(ans_grad.dtype)
+        reflect = ctx.reflect
+        if reflect != 0.0:
+            x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
+        return ans_grad * is_same.to(ans_grad.dtype), None, None, None, None
 
 def random_clamp(x: Tensor,
                  min: Optional[float] = None,
                  max: Optional[float] = None,
-                 prob: float = 0.5):
-    return RandomClampFunction.apply(x, min, max, prob)
+                 prob: float = 0.5,
+                 reflect: float = 0.0):
+    return RandomClampFunction.apply(x, min, max, prob, reflect)
 
 
 def random_cast_to_half(x: Tensor,
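The scaling.py hunk threads the new reflect argument through RandomClampFunction: the forward reflects the clamped output slightly about the original value, and the backward scales the surviving gradient by (1 + reflect) and subtracts reflect times the incoming gradient. As displayed above, the new backward still returns ans_grad * is_same.to(ans_grad.dtype) rather than the x_grad it just computed; the self-contained sketch below (my own standalone version, not the repository's file) returns the adjusted x_grad so that the backward is the exact gradient of the reflected forward:

import torch
from torch import Tensor
from typing import Optional, Tuple


class ReflectClampSketch(torch.autograd.Function):
    """Standalone sketch of a reflect-style random clamp (not icefall's exact
    class). Forward: clamp a random subset of elements, then reflect the
    result slightly about the original value. Backward: the exact gradient
    of that forward."""

    @staticmethod
    def forward(ctx, x: Tensor, min: Optional[float], max: Optional[float],
                prob: float, reflect: float) -> Tensor:
        x_clamped = torch.clamp(x, min=min, max=max)
        mask = torch.rand_like(x) < prob
        ans = torch.where(mask, x_clamped, x)
        if x.requires_grad:
            ctx.save_for_backward(ans == x)
            ctx.reflect = reflect
        if reflect != 0.0:
            ans = ans * (1.0 + reflect) - x * reflect
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
        (is_same,) = ctx.saved_tensors
        x_grad = ans_grad * is_same.to(ans_grad.dtype)
        reflect = ctx.reflect
        if reflect != 0.0:
            # d(ans)/dx is 1 for un-clamped elements and -reflect for clamped
            # ones, which is exactly what this adjustment produces.
            x_grad = x_grad * (1.0 + reflect) - ans_grad * reflect
        return x_grad, None, None, None, None


def reflect_clamp(x: Tensor, min: Optional[float] = None,
                  max: Optional[float] = None, prob: float = 0.5,
                  reflect: float = 0.0) -> Tensor:
    return ReflectClampSketch.apply(x, min, max, prob, reflect)


if __name__ == "__main__":
    x = (torch.randn(4, 5) * 10.0).requires_grad_(True)
    y = reflect_clamp(x, min=-5.0, max=5.0, prob=0.5, reflect=0.1)
    y.sum().backward()
    print(x.grad)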