Have 2 scales on dropout

Daniel Povey 2022-06-07 18:26:23 +08:00
parent 53ca61db7a
commit 5d24489752
2 changed files with 22 additions and 11 deletions


@@ -18,6 +18,7 @@
 import collections
 from itertools import repeat
 from typing import Optional, Tuple
+import logging
 import random
 import torch
@@ -731,16 +732,21 @@ class Decorrelate(torch.nn.Module):
            noise, in such a way that the self-correlation and cross-correlation
            statistics match those of dropout with the same `dropout_rate`
            (assuming we applied the transform, e.g. if apply_prob == 1.0).
+           This number applies when the features are un-correlated.
+    dropout_max_rate: This is an upper limit, for safety, on how aggressive the
+           randomization can be.
     eps: An epsilon used to prevent division by zero.
    """
    def __init__(self,
                 apply_prob: float = 0.25,
-                dropout_rate: float = 0.1,
+                dropout_rate: float = 0.01,
+                dropout_max_rate: float = 0.1,
                 eps: float = 1.0e-04,
                 channel_dim: int = -1):
        super(Decorrelate, self).__init__()
        self.apply_prob = apply_prob
        self.dropout_rate = dropout_rate
+       self.dropout_max_rate = dropout_max_rate
        self.channel_dim = channel_dim
        self.eps = eps
@@ -783,10 +789,16 @@ class Decorrelate(torch.nn.Module):
            cov = self._get_covar(x)
            cov = self._normalize_covar(cov, self.eps)
            avg_squared_eig = (cov**2).sum(dim=0).mean()
+           if random.random() < 0.001 or __name__ == "__main__":
+               logging.info(f"Decorrelate: avg_squared_eig = {avg_squared_eig}")
            # the odd-looking formula below was obtained empirically, to match
            # the self-product and cross-correlation statistics of dropout
-           rand_scale = ((self.dropout_rate / (1.0 - self.dropout_rate)) ** 0.5) / avg_squared_eig
+           rand_scale1 = ((self.dropout_max_rate / (1.0 - self.dropout_max_rate)) ** 0.5) / avg_squared_eig
+           rand_scale2 = ((self.dropout_rate / (1.0 - self.dropout_rate)) ** 0.5)
+           rand_scale = torch.minimum(rand_scale1, torch.tensor(rand_scale2, device=x.device))
            # by multiplying by `cov`, then randomizing the sign of elements, then
            # multiplying by `cov` again, we are generating something that has
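
To see how the two scales interact, here is a minimal standalone sketch (not part of this commit; the helper name clamped_rand_scale is invented for illustration). With the defaults dropout_rate=0.01 and dropout_max_rate=0.1, the un-correlated scale sqrt(p / (1 - p)) applies while avg_squared_eig is small, and the capped max-rate term takes over, shrinking the randomization, once the features become strongly correlated:

import torch

def clamped_rand_scale(avg_squared_eig: torch.Tensor,
                       dropout_rate: float = 0.01,
                       dropout_max_rate: float = 0.1) -> torch.Tensor:
    # safety cap: shrinks as the features become more correlated
    rand_scale1 = ((dropout_max_rate / (1.0 - dropout_max_rate)) ** 0.5) / avg_squared_eig
    # scale implied by dropout_rate alone (the un-correlated case)
    rand_scale2 = (dropout_rate / (1.0 - dropout_rate)) ** 0.5
    return torch.minimum(rand_scale1, torch.tensor(rand_scale2))

for eig in [1.0, 3.0, 10.0]:
    print(eig, clamped_rand_scale(torch.tensor(eig)).item())
# eig 1.0 and 3.0: the un-correlated scale (~0.1005) wins;
# eig 10.0: the capped scale (~0.0333) takes over.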
@@ -900,12 +912,13 @@ def _test_gauss_proj_drop():
    m2.eval()

def _test_decorrelate():
+   logging.getLogger().setLevel(logging.INFO)
    x = torch.randn(30000, 384)
    for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
        m1 = torch.nn.Dropout(dropout_rate)
-       m2 = Decorrelate(apply_prob=1.0, rand_scale=dropout_rate)
+       m2 = Decorrelate(apply_prob=1.0, dropout_rate=dropout_rate)
        for mode in ['train', 'eval']:
            y1 = m1(x)
            y2 = m2(x)
@@ -921,9 +934,8 @@ def _test_decorrelate():
if __name__ == "__main__":
    _test_decorrelate()
-   if False:
-       _test_gauss_proj_drop()
-       _test_activation_balancer_sign()
-       _test_activation_balancer_magnitude()
-       _test_basic_norm()
-       _test_double_swish_deriv()
+   _test_gauss_proj_drop()
+   _test_activation_balancer_sign()
+   _test_activation_balancer_magnitude()
+   _test_basic_norm()
+   _test_double_swish_deriv()


@@ -199,8 +199,7 @@ class ConformerEncoderLayer(nn.Module):
        )
        self.dropout = torch.nn.Dropout(dropout)
-       self.decorrelate = Decorrelate(apply_prob=0.25,
-                                      dropout_rate=0.01)
+       self.decorrelate = Decorrelate(apply_prob=0.25)

    def forward(
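
For context, a hypothetical usage sketch (the body of forward is not part of this diff): Decorrelate preserves the tensor's shape, so it can wrap activations anywhere dropout could, for example at the end of the layer:

# hypothetical placement; the actual call site is not shown in this diff
src = self.decorrelate(src)
return src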