Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-08 16:44:20 +00:00)
Have 2 scales on dropout
parent 53ca61db7a
commit 5d24489752
@@ -18,6 +18,7 @@
 import collections
 from itertools import repeat
 from typing import Optional, Tuple
+import logging

 import random
 import torch
@@ -731,16 +732,21 @@ class Decorrelate(torch.nn.Module):
           noise, in such a way that the self-correlation and cross-correlation
           statistics match those of dropout with the same `dropout_rate`
           (assuming we applied the transform, e.g. if apply_prob == 1.0).
+          This number applies when the features are un-correlated.
+      dropout_max_rate: This is an upper limit, for safety, on how aggressive the
+          randomization can be.
       eps: An epsilon used to prevent division by zero.
     """
     def __init__(self,
                  apply_prob: float = 0.25,
-                 dropout_rate: float = 0.1,
+                 dropout_rate: float = 0.01,
+                 dropout_max_rate: float = 0.1,
                  eps: float = 1.0e-04,
                  channel_dim: int = -1):
         super(Decorrelate, self).__init__()
         self.apply_prob = apply_prob
         self.dropout_rate = dropout_rate
+        self.dropout_max_rate = dropout_max_rate
         self.channel_dim = channel_dim
         self.eps = eps

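A usage sketch of the new signature (the `scaling` import path is assumed, since this diff view does not show filenames; `forward` taking a single tensor is inferred from the test further down):

    import torch
    from scaling import Decorrelate  # module path assumed, not visible in this diff view

    m = Decorrelate(apply_prob=0.25,       # transform applied on ~25% of training calls
                    dropout_rate=0.01,     # target noise scale (new default)
                    dropout_max_rate=0.1)  # safety cap on how strong the noise can get
    m.train()
    y = m(torch.randn(500, 384))           # output has the same shape as the input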
@@ -783,10 +789,16 @@ class Decorrelate(torch.nn.Module):
             cov = self._get_covar(x)
             cov = self._normalize_covar(cov, self.eps)
             avg_squared_eig = (cov**2).sum(dim=0).mean()
+            if random.random() < 0.001 or __name__ == "__main__":
+                logging.info(f"Decorrelate: avg_squared_eig = {avg_squared_eig}")

             # the odd-looking formula below was obtained empirically, to match
             # the self-product and cross-correlation statistics of dropout
-            rand_scale = ((self.dropout_rate / (1.0 - self.dropout_rate)) ** 0.5) / avg_squared_eig
+            rand_scale1 = ((self.dropout_max_rate / (1.0 - self.dropout_max_rate)) ** 0.5) / avg_squared_eig
+            rand_scale2 = ((self.dropout_rate / (1.0 - self.dropout_rate)) ** 0.5)
+            rand_scale = torch.minimum(rand_scale1, torch.tensor(rand_scale2, device=x.device))
+

             # by multiplying by `cov`, then randomizing the sign of elements, then
             # multiplying by `cov` again, we are generating something that has
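To unpack the two scales above (a standalone sketch; names and numbers are illustrative, not from the patch): inverted dropout at rate p rescales survivors by 1/(1-p), which inflates the second moment to E[x^2]/(1-p), i.e. it adds relative noise power p/(1-p); hence sqrt(p/(1-p)) as the matching multiplicative noise scale. `rand_scale2` applies that directly to `dropout_rate`, while `rand_scale1` divides the `dropout_max_rate` version by the eigenvalue statistic, so the effective scale shrinks when `avg_squared_eig` is large:

    import torch

    def two_scale_rand_scale(dropout_rate: float,
                             dropout_max_rate: float,
                             avg_squared_eig: torch.Tensor) -> torch.Tensor:
        # Restatement of the hunk above outside the class, for illustration only.
        s_cap = ((dropout_max_rate / (1.0 - dropout_max_rate)) ** 0.5) / avg_squared_eig
        s_target = (dropout_rate / (1.0 - dropout_rate)) ** 0.5
        return torch.minimum(s_cap, torch.tensor(s_target))

    # With a large eigenvalue statistic the safety cap binds:
    print(two_scale_rand_scale(0.01, 0.1, torch.tensor(5.0)))  # ~0.067 < sqrt(0.01/0.99) ~ 0.100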
@@ -900,12 +912,13 @@ def _test_gauss_proj_drop():
     m2.eval()

 def _test_decorrelate():
+    logging.getLogger().setLevel(logging.INFO)
     x = torch.randn(30000, 384)


     for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
         m1 = torch.nn.Dropout(dropout_rate)
-        m2 = Decorrelate(apply_prob=1.0, rand_scale=dropout_rate)
+        m2 = Decorrelate(apply_prob=1.0, dropout_rate=dropout_rate)
         for mode in ['train', 'eval']:
             y1 = m1(x)
             y2 = m2(x)
@@ -921,9 +934,8 @@ def _test_decorrelate():

 if __name__ == "__main__":
     _test_decorrelate()
-    if False:
-        _test_gauss_proj_drop()
-        _test_activation_balancer_sign()
-        _test_activation_balancer_magnitude()
-        _test_basic_norm()
-        _test_double_swish_deriv()
+    _test_gauss_proj_drop()
+    _test_activation_balancer_sign()
+    _test_activation_balancer_magnitude()
+    _test_basic_norm()
+    _test_double_swish_deriv()
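For context, a hypothetical check in the spirit of `_test_decorrelate` (the actual comparison logic sits outside this diff's hunks): training-mode `torch.nn.Dropout` rescales by 1/(1-p), so the ratio of second moments should sit near 1/(1-p), which is the statistic `Decorrelate` is tuned to match:

    import torch

    def second_moment_ratio(p: float = 0.1) -> torch.Tensor:
        # Hypothetical helper, not from the patch.
        x = torch.randn(30000, 384)
        y = torch.nn.Dropout(p)(x)  # a fresh module is in training mode by default
        return (y ** 2).mean() / (x ** 2).mean()

    print(second_moment_ratio(0.1))  # close to 1/(1-0.1) = 1.111...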
@@ -199,8 +199,7 @@ class ConformerEncoderLayer(nn.Module):
         )

         self.dropout = torch.nn.Dropout(dropout)
-        self.decorrelate = Decorrelate(apply_prob=0.25,
-                                       dropout_rate=0.01)
+        self.decorrelate = Decorrelate(apply_prob=0.25)


     def forward(
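Why the explicit argument can be dropped here (inferred from the first hunk of this commit): the default `dropout_rate` in `Decorrelate.__init__` moves from 0.1 to 0.01, so passing `dropout_rate=0.01` in the Conformer layer becomes redundant:

    from scaling import Decorrelate  # module path assumed, as above

    # Equivalent after this commit, since dropout_rate now defaults to 0.01:
    a = Decorrelate(apply_prob=0.25, dropout_rate=0.01)
    b = Decorrelate(apply_prob=0.25)
    assert a.dropout_rate == b.dropout_rate == 0.01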