mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-09 17:14:20 +00:00
Have 2 scales on dropout
This commit is contained in:
parent
53ca61db7a
commit
5d24489752
@ -18,6 +18,7 @@
|
|||||||
import collections
|
import collections
|
||||||
from itertools import repeat
|
from itertools import repeat
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
import logging
|
||||||
|
|
||||||
import random
|
import random
|
||||||
import torch
|
import torch
|
||||||
@ -731,16 +732,21 @@ class Decorrelate(torch.nn.Module):
|
|||||||
noise, in such a way that the self-correlation and cross-correlation
|
noise, in such a way that the self-correlation and cross-correlation
|
||||||
statistics match those dropout with the same `dropout_rate`
|
statistics match those dropout with the same `dropout_rate`
|
||||||
(assuming we applied the transform, e.g. if apply_prob == 1.0)
|
(assuming we applied the transform, e.g. if apply_prob == 1.0)
|
||||||
|
This number applies when the features are un-correlated.
|
||||||
|
dropout_max_rate: This is an upper limit, for safety, on how aggressive the
|
||||||
|
randomization can be.
|
||||||
eps: An epsilon used to prevent division by zero.
|
eps: An epsilon used to prevent division by zero.
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
apply_prob: float = 0.25,
|
apply_prob: float = 0.25,
|
||||||
dropout_rate: float = 0.1,
|
dropout_rate: float = 0.01,
|
||||||
|
dropout_max_rate: float = 0.1,
|
||||||
eps: float = 1.0e-04,
|
eps: float = 1.0e-04,
|
||||||
channel_dim: int = -1):
|
channel_dim: int = -1):
|
||||||
super(Decorrelate, self).__init__()
|
super(Decorrelate, self).__init__()
|
||||||
self.apply_prob = apply_prob
|
self.apply_prob = apply_prob
|
||||||
self.dropout_rate = dropout_rate
|
self.dropout_rate = dropout_rate
|
||||||
|
self.dropout_max_rate = dropout_max_rate
|
||||||
self.channel_dim = channel_dim
|
self.channel_dim = channel_dim
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
|
|
||||||
@ -783,10 +789,16 @@ class Decorrelate(torch.nn.Module):
|
|||||||
cov = self._get_covar(x)
|
cov = self._get_covar(x)
|
||||||
cov = self._normalize_covar(cov, self.eps)
|
cov = self._normalize_covar(cov, self.eps)
|
||||||
avg_squared_eig = (cov**2).sum(dim=0).mean()
|
avg_squared_eig = (cov**2).sum(dim=0).mean()
|
||||||
|
if random.random() < 0.001 or __name__ == "__main__":
|
||||||
|
logging.info(f"Decorrelate: avg_squared_eig = {avg_squared_eig}")
|
||||||
|
|
||||||
# the odd-looking formula below was obtained empirically, to match
|
# the odd-looking formula below was obtained empirically, to match
|
||||||
# the self-product and cross-correlation statistics of dropout
|
# the self-product and cross-correlation statistics of dropout
|
||||||
rand_scale = ((self.dropout_rate / (1.0 - self.dropout_rate)) ** 0.5) / avg_squared_eig
|
|
||||||
|
rand_scale1 = ((self.dropout_max_rate / (1.0 - self.dropout_max_rate)) ** 0.5) / avg_squared_eig
|
||||||
|
rand_scale2 = ((self.dropout_rate / (1.0 - self.dropout_rate)) ** 0.5)
|
||||||
|
rand_scale = torch.minimum(rand_scale1, torch.tensor(rand_scale2, device=x.device))
|
||||||
|
|
||||||
|
|
||||||
# by multiplying by `cov`, then randomizing the sign of elements, then
|
# by multiplying by `cov`, then randomizing the sign of elements, then
|
||||||
# multiplying by `cov` again, we are generating something that has
|
# multiplying by `cov` again, we are generating something that has
|
||||||
@ -900,12 +912,13 @@ def _test_gauss_proj_drop():
|
|||||||
m2.eval()
|
m2.eval()
|
||||||
|
|
||||||
def _test_decorrelate():
|
def _test_decorrelate():
|
||||||
|
logging.getLogger().setLevel(logging.INFO)
|
||||||
x = torch.randn(30000, 384)
|
x = torch.randn(30000, 384)
|
||||||
|
|
||||||
|
|
||||||
for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
|
for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
|
||||||
m1 = torch.nn.Dropout(dropout_rate)
|
m1 = torch.nn.Dropout(dropout_rate)
|
||||||
m2 = Decorrelate(apply_prob=1.0, rand_scale=dropout_rate)
|
m2 = Decorrelate(apply_prob=1.0, dropout_rate=dropout_rate)
|
||||||
for mode in ['train', 'eval']:
|
for mode in ['train', 'eval']:
|
||||||
y1 = m1(x)
|
y1 = m1(x)
|
||||||
y2 = m2(x)
|
y2 = m2(x)
|
||||||
@ -921,9 +934,8 @@ def _test_decorrelate():
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
_test_decorrelate()
|
_test_decorrelate()
|
||||||
if False:
|
_test_gauss_proj_drop()
|
||||||
_test_gauss_proj_drop()
|
_test_activation_balancer_sign()
|
||||||
_test_activation_balancer_sign()
|
_test_activation_balancer_magnitude()
|
||||||
_test_activation_balancer_magnitude()
|
_test_basic_norm()
|
||||||
_test_basic_norm()
|
_test_double_swish_deriv()
|
||||||
_test_double_swish_deriv()
|
|
||||||
|
@ -199,8 +199,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.dropout = torch.nn.Dropout(dropout)
|
self.dropout = torch.nn.Dropout(dropout)
|
||||||
self.decorrelate = Decorrelate(apply_prob=0.25,
|
self.decorrelate = Decorrelate(apply_prob=0.25)
|
||||||
dropout_rate=0.01)
|
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user