Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Merge branch 'scaled_adam_exp150' into scaled_adam_exp155

# Conflicts:
#	egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py

Commit 6e6209419c
egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py

@@ -36,6 +36,8 @@ from scaling import (
     _diag,
     random_clamp,
     with_loss,
+    softmax,
+    RandomGrad,
 )
 from torch import Tensor, nn

@@ -304,7 +306,7 @@ class ConformerEncoderLayer(nn.Module):
                              whitening_limit=5.0,
                              prob=(0.025, 0.25),
                              grad_scale=0.01)
-
+        self.random_grad = RandomGrad()

     def forward(
         self,
@@ -364,7 +366,7 @@ class ConformerEncoderLayer(nn.Module):
             bypass_scale = bypass_scale.clamp(min=0.1, max=1.0)
         src = src_orig + delta * self.bypass_scale

-        return self.whiten(src)
+        return self.random_grad(self.whiten(src))


 class ConformerEncoder(nn.Module):
@@ -870,8 +872,6 @@ class RelPositionMultiheadAttention(nn.Module):
         self.copy_pos_query = Identity()
         self.copy_query = Identity()

-        self.in_balancer = ActivationBalancer(3 * attention_dim,
-                                              channel_dim=-1, max_abs=5.0)
         self.out_proj = ScaledLinear(
             attention_dim // 2, embed_dim, bias=True, initial_scale=0.05
         )
@@ -931,7 +931,7 @@ class RelPositionMultiheadAttention(nn.Module):
             and S is the sequence length.
         """
         x, weights = self.multi_head_attention_forward(
-            self.in_balancer(self.in_proj(x)),
+            self.in_proj(x),
             self.linear_pos(pos_emb),
             self.attention_dim,
             self.num_heads,
@@ -1121,7 +1121,8 @@ class RelPositionMultiheadAttention(nn.Module):
             attn_output_weights = random_clamp(attn_output_weights,
                                                min=-attn_weights_max,
                                                max=attn_weights_max,
-                                               prob=0.5)
+                                               prob=0.5,
+                                               reflect=0.1)

         if training and random.random() < 0.1:
             # This is a harder way of limiting the attention scores to not be too large.
@@ -1170,7 +1171,7 @@ class RelPositionMultiheadAttention(nn.Module):
             bsz * num_heads, seq_len, seq_len
         )

-        attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
+        attn_output_weights = softmax(attn_output_weights, dim=-1)
         attn_output_weights = nn.functional.dropout(
             attn_output_weights, p=dropout_p, training=training
         )
@@ -1583,7 +1584,7 @@ class AttentionCombine(nn.Module):
                                        single_prob_mask)

             weights = weights.masked_fill(mask, float('-inf'))
-        weights = weights.softmax(dim=1)
+        weights = softmax(weights, dim=1)

         # (num_frames, num_channels, num_inputs) * (num_frames, num_inputs, 1) -> (num_frames, num_channels, 1),
         ans = torch.matmul(stacked_inputs, weights.unsqueeze(2))
(second changed file; its path was not preserved in the extracted page — the hunks below modify the scaling utilities imported by conformer.py above, presumably egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py)
@@ -165,26 +165,125 @@ class RandomClampFunction(torch.autograd.Function):
                 x: Tensor,
                 min: Optional[float],
                 max: Optional[float],
-                prob: float) -> Tensor:
+                prob: float,
+                reflect: float) -> Tensor:
         x_clamped = torch.clamp(x, min=min, max=max)
         mask = torch.rand_like(x) < prob
         ans = torch.where(mask, x_clamped, x)
         if x.requires_grad:
             ctx.save_for_backward(ans == x)
+            ctx.reflect = reflect
+        if reflect != 0.0:
+            ans = ans * (1.0 + reflect) - (x * reflect)
         return ans

     @staticmethod
-    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None]:
+    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
         is_same, = ctx.saved_tensors
-        return ans_grad * is_same.to(ans_grad.dtype), None, None, None
+        x_grad = ans_grad * is_same.to(ans_grad.dtype)
+        reflect = ctx.reflect
+        if reflect != 0.0:
+            x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
+        return x_grad, None, None, None, None


 def random_clamp(x: Tensor,
                  min: Optional[float] = None,
                  max: Optional[float] = None,
-                 prob: float = 0.5):
-    return RandomClampFunction.apply(x, min, max, prob)
+                 prob: float = 0.5,
+                 reflect: float = 0.0):
+    return RandomClampFunction.apply(x, min, max, prob, reflect)

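The reflect argument above is new in this commit. As a minimal sketch (not part of the diff, assuming the patched scaling.py is importable): with reflect=r, an element that gets clamped is not pinned at the boundary but pulled back inside it by r times the overshoot, and its gradient becomes -r instead of 0, so clamped values still feel some pressure from the loss.

    # Minimal sketch; prob=1.0 makes the clamp deterministic.
    import torch
    from scaling import random_clamp  # assumes the patched scaling.py is on the path

    x = torch.tensor([0.0, 2.0, -3.0], requires_grad=True)
    y = random_clamp(x, min=-1.0, max=1.0, prob=1.0, reflect=0.1)
    # inside the range   -> unchanged:             0.0
    # overshoot of +1.0  ->  1.0 - 0.1 * 1.0  =    0.9
    # overshoot of -2.0  -> -1.0 + 0.1 * 2.0  =   -0.8
    print(y)          # tensor([ 0.0000,  0.9000, -0.8000], grad_fn=...)
    y.sum().backward()
    print(x.grad)     # tensor([ 1.0000, -0.1000, -0.1000]): clamped entries get -reflect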
+
+def random_cast_to_half(x: Tensor,
+                        min_abs: float = 5.0e-06) -> Tensor:
+    """
+    A randomized way of casting a floating point value to half precision.
+    """
+    if x.dtype == torch.float16:
+        return x
+    x_sign = x.sign()
+    x_abs = x.abs()
+    is_too_small = (x_abs < min_abs)
+    # for elements where is_too_small is true, random_val will contain +-min_abs with
+    # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations,
+    # for those elements].
+    random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
+    return torch.where(is_too_small, random_val, x).to(torch.float16)
+
+
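A small, hedged check of the expectation-preserving behaviour described in the comment above (a sketch, not in the commit, assuming random_cast_to_half is importable from the patched scaling.py): magnitudes below min_abs are cast to either 0 or ±min_abs per element, but the average over many elements reproduces the original value.

    import torch
    from scaling import random_cast_to_half

    x = torch.full((100000,), 2.0e-06)   # below the min_abs=5.0e-06 threshold
    y = random_cast_to_half(x)           # each element becomes 0.0 or ~5.0e-06 (float16)
    print(y.unique())                    # roughly tensor([0.0000e+00, 5.0068e-06], dtype=torch.float16)
    print(y.float().mean().item())       # close to 2.0e-06 in expectation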
+class RandomGradFunction(torch.autograd.Function):
+    """
+    Does nothing in forward pass; in backward pass, gets rid of very small grads using
+    randomized approach that preserves expectations (intended to reduce roundoff).
+    """
+    @staticmethod
+    def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
+        ctx.min_abs = min_abs
+        return x
+
+    @staticmethod
+    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
+        min_abs = ctx.min_abs
+        if ans_grad.dtype == torch.float16:
+            return random_cast_to_half(ans_grad.to(torch.float32),
+                                       min_abs=ctx.min_abs), None
+        else:
+            return ans_grad, None
+
+
+class RandomGrad(torch.nn.Module):
+    """
+    Gets rid of very small gradients using an expectation-preserving method, intended to increase
+    accuracy of training when using amp (automatic mixed precision)
+    """
+    def __init__(self,
+                 min_abs: float = 5.0e-06):
+        super(RandomGrad, self).__init__()
+        self.min_abs = min_abs
+
+    def forward(self,
+                x: Tensor):
+        if torch.jit.is_scripting() or not self.training:
+            return x
+        else:
+            return RandomGradFunction.apply(x, self.min_abs)
+
+
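In this commit RandomGrad is attached to the ConformerEncoderLayer output (see the conformer.py hunks above). The wiring below is only a usage sketch under the assumption that the patched scaling.py is importable: the module is an exact identity in the forward pass, a no-op in eval or scripting mode, and only float16 gradients are touched in the backward pass.

    import torch
    from scaling import RandomGrad

    block = torch.nn.Sequential(
        torch.nn.Linear(16, 16),
        RandomGrad(min_abs=5.0e-06),     # identity forward; randomized rounding of fp16 grads
    )
    block.train()
    x = torch.randn(4, 16, requires_grad=True)
    y = block(x)
    assert torch.equal(y, block[0](x))   # forward output is unchanged by RandomGrad
    y.sum().backward()                   # fp32 grads pass through untouched here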
+class SoftmaxFunction(torch.autograd.Function):
+    """
+    Tries to handle half-precision derivatives in a randomized way that should
+    be more accurate for training than the default behavior.
+    """
+    @staticmethod
+    def forward(ctx, x: Tensor, dim: int):
+        ans = x.softmax(dim=dim)
+        # if x dtype is float16, x.softmax() returns a float32 because
+        # (presumably) that op does not support float16, and autocast
+        # is enabled.
+        ctx.save_for_backward(ans)
+        ctx.x_dtype = x.dtype
+        ctx.dim = dim
+        return ans
+
+    @staticmethod
+    def backward(ctx, ans_grad: Tensor):
+        ans, = ctx.saved_tensors
+        with torch.cuda.amp.autocast(enabled=False):
+            ans_grad = ans_grad.to(torch.float32)
+            ans = ans.to(torch.float32)
+            x_grad = ans_grad * ans
+            x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
+            if ctx.x_dtype == torch.float16:
+                x_grad = random_cast_to_half(x_grad)
+
+        return x_grad, None
+
+
+def softmax(x: Tensor,
+            dim: int):
+    return SoftmaxFunction.apply(x, dim)
+
+
 class MaxEigLimiterFunction(torch.autograd.Function):
     @staticmethod
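The softmax function defined above is used as a drop-in replacement in the conformer.py hunks (attention weights and AttentionCombine). A short sketch of that usage, assuming the patched scaling.py is importable: the forward result matches torch's softmax exactly; only the backward differs, being computed in float32 and randomly rounded back to float16 when the input was half precision (e.g. under AMP).

    import torch
    from scaling import softmax

    scores = torch.randn(3, 5, requires_grad=True)
    probs = softmax(scores, dim=-1)
    assert torch.allclose(probs, scores.softmax(dim=-1))   # identical forward result
    probs[:, 0].sum().backward()                           # float32 path: plain softmax backward
    print(scores.grad)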
@@ -822,7 +921,6 @@ class DoubleSwish(torch.nn.Module):


 def _test_max_eig():
-
     for proportion in [0.1, 0.5, 10.0]:
         logging.info(f"proportion = {proportion}")
         x = torch.randn(100, 128)
@@ -846,7 +944,7 @@ def _test_max_eig():
         y.backward(gradient=y_grad)

         if proportion < 0.2:
-            assert torch.allclose(x.grad, y_grad)
+            assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
         elif proportion > 1.0:
             assert not torch.allclose(x.grad, y_grad)

@@ -957,11 +1055,24 @@ def _test_double_swish_deriv():
     torch.autograd.gradcheck(m, x)


+def _test_softmax():
+    a = torch.randn(2, 10, dtype=torch.float64)
+    b = a.clone()
+    a.requires_grad = True
+    b.requires_grad = True
+    a.softmax(dim=1)[:,0].sum().backward()
+    print("a grad = ", a.grad)
+    softmax(b, dim=1)[:,0].sum().backward()
+    print("b grad = ", b.grad)
+    assert torch.allclose(a.grad, b.grad)
+
+

 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    _test_softmax()
     _test_whiten()
     _test_max_eig()
     _test_activation_balancer_sign()