Have 2 scales on dropout

Daniel Povey 2022-06-07 18:26:23 +08:00
parent 53ca61db7a
commit 5d24489752
2 changed files with 22 additions and 11 deletions


@ -18,6 +18,7 @@
import collections
from itertools import repeat
from typing import Optional, Tuple
import logging
import random
import torch
@ -731,16 +732,21 @@ class Decorrelate(torch.nn.Module):
noise, in such a way that the self-correlation and cross-correlation
statistics match those of dropout with the same `dropout_rate`
(assuming we applied the transform, e.g. if apply_prob == 1.0)
This number applies when the features are un-correlated.
dropout_max_rate: This is an upper limit, for safety, on how aggressive the
randomization can be.
eps: An epsilon used to prevent division by zero.
"""
def __init__(self,
apply_prob: float = 0.25,
dropout_rate: float = 0.1,
dropout_rate: float = 0.01,
dropout_max_rate: float = 0.1,
eps: float = 1.0e-04,
channel_dim: int = -1):
super(Decorrelate, self).__init__()
self.apply_prob = apply_prob
self.dropout_rate = dropout_rate
self.dropout_max_rate = dropout_max_rate
self.channel_dim = channel_dim
self.eps = eps
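A minimal usage sketch of the constructor above (not part of the commit; the shapes and values are illustrative, and the eval-mode comment is an assumption about how such regularization modules usually behave):

import torch

decorrelate = Decorrelate(apply_prob=0.25,      # probability of actually applying the transform
                          dropout_rate=0.01,    # target rate when features are uncorrelated (new default)
                          dropout_max_rate=0.1) # safety cap on how aggressive the randomization gets
decorrelate.train()
x = torch.randn(500, 384)   # (frames, channels); channel_dim defaults to -1
y = decorrelate(x)          # same shape as x; presumably a no-op in eval mode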
@ -783,10 +789,16 @@ class Decorrelate(torch.nn.Module):
cov = self._get_covar(x)
cov = self._normalize_covar(cov, self.eps)
avg_squared_eig = (cov**2).sum(dim=0).mean()
if random.random() < 0.001 or __name__ == "__main__":
logging.info(f"Decorrelate: avg_squared_eig = {avg_squared_eig}")
# the odd-looking formula below was obtained empirically, to match
# the self-product and cross-correlation statistics of dropout
rand_scale = ((self.dropout_rate / (1.0 - self.dropout_rate)) ** 0.5) / avg_squared_eig
rand_scale1 = ((self.dropout_max_rate / (1.0 - self.dropout_max_rate)) ** 0.5) / avg_squared_eig
rand_scale2 = ((self.dropout_rate / (1.0 - self.dropout_rate)) ** 0.5)
rand_scale = torch.minimum(rand_scale1, torch.tensor(rand_scale2, device=x.device))
# by multiplying by `cov`, then randomizing the sign of elements, then
# multiplying by `cov` again, we are generating something that has
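To make the two scales concrete, here is an illustrative calculation with the new defaults (avg_squared_eig is measured from the normalized covariance at runtime, so the value below is made up):

dropout_rate, dropout_max_rate = 0.01, 0.1
avg_squared_eig = 2.5   # hypothetical; depends on how correlated the features are
rand_scale1 = ((dropout_max_rate / (1.0 - dropout_max_rate)) ** 0.5) / avg_squared_eig   # ~0.133
rand_scale2 = (dropout_rate / (1.0 - dropout_rate)) ** 0.5                               # ~0.1005
rand_scale = min(rand_scale1, rand_scale2)                                               # ~0.1005
# When the features are strongly correlated, avg_squared_eig presumably grows and
# rand_scale1 shrinks, so dropout_max_rate becomes the binding limit, which is the
# "safety" role described in the docstring.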
@ -900,12 +912,13 @@ def _test_gauss_proj_drop():
m2.eval()
def _test_decorrelate():
logging.getLogger().setLevel(logging.INFO)
x = torch.randn(30000, 384)
for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
m1 = torch.nn.Dropout(dropout_rate)
m2 = Decorrelate(apply_prob=1.0, rand_scale=dropout_rate)
m2 = Decorrelate(apply_prob=1.0, dropout_rate=dropout_rate)
for mode in ['train', 'eval']:
y1 = m1(x)
y2 = m2(x)
@ -921,9 +934,8 @@ def _test_decorrelate():
if __name__ == "__main__":
_test_decorrelate()
if False:
_test_gauss_proj_drop()
_test_activation_balancer_sign()
_test_activation_balancer_magnitude()
_test_basic_norm()
_test_double_swish_deriv()
_test_gauss_proj_drop()
_test_activation_balancer_sign()
_test_activation_balancer_magnitude()
_test_basic_norm()
_test_double_swish_deriv()
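For reference, a hedged sketch of the kind of statistic such a test might compare; the actual assertions are outside this hunk, and the E[y**2] / E[x**2] check below is only one plausible choice:

import torch

x = torch.randn(30000, 384)
dropout_rate = 0.1
m1 = torch.nn.Dropout(dropout_rate)
m2 = Decorrelate(apply_prob=1.0, dropout_rate=dropout_rate)
m1.train()
m2.train()
y1, y2 = m1(x), m2(x)
# For standard dropout, E[y**2] / E[x**2] is 1 / (1 - p); per the docstring,
# Decorrelate is tuned so its self-product statistics roughly match this.
print((y1 ** 2).mean() / (x ** 2).mean())
print((y2 ** 2).mean() / (x ** 2).mean())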


@ -199,8 +199,7 @@ class ConformerEncoderLayer(nn.Module):
)
self.dropout = torch.nn.Dropout(dropout)
self.decorrelate = Decorrelate(apply_prob=0.25,
dropout_rate=0.01)
self.decorrelate = Decorrelate(apply_prob=0.25)
def forward(
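The call-site change above simply drops an argument that now matches the new default; for illustration (not part of the commit), the two constructions below are equivalent given the defaults introduced in this commit:

self.decorrelate = Decorrelate(apply_prob=0.25)                    # new call site
self.decorrelate = Decorrelate(apply_prob=0.25, dropout_rate=0.01,
                               dropout_max_rate=0.1)               # defaults spelled out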