From 5d244897522f00248474c6e7c2a20afc80239eaf Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 7 Jun 2022 18:26:23 +0800
Subject: [PATCH] Have 2 scales on dropout

---
 .../pruned_transducer_stateless2/scaling.py   | 30 +++++++++++++------
 .../pruned_transducer_stateless5/conformer.py |  3 +-
 2 files changed, 22 insertions(+), 11 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index 75a587370..87e3cbd18 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -18,6 +18,7 @@
 import collections
 from itertools import repeat
 from typing import Optional, Tuple
+import logging
 import random

 import torch
@@ -731,16 +732,21 @@ class Decorrelate(torch.nn.Module):
            noise, in such a way that the self-correlation and cross-correlation
            statistics match those of dropout with the same `dropout_rate` (assuming
            we applied the transform, e.g. if apply_prob == 1.0).
+           This number applies when the features are uncorrelated.
+      dropout_max_rate: This is an upper limit, for safety, on how aggressive the
+           randomization can be.
       eps:  An epsilon used to prevent division by zero.
     """
     def __init__(self,
                  apply_prob: float = 0.25,
-                 dropout_rate: float = 0.1,
+                 dropout_rate: float = 0.01,
+                 dropout_max_rate: float = 0.1,
                  eps: float = 1.0e-04,
                  channel_dim: int = -1):
         super(Decorrelate, self).__init__()
         self.apply_prob = apply_prob
         self.dropout_rate = dropout_rate
+        self.dropout_max_rate = dropout_max_rate
         self.channel_dim = channel_dim
         self.eps = eps

@@ -783,10 +789,16 @@ class Decorrelate(torch.nn.Module):
             cov = self._get_covar(x)
             cov = self._normalize_covar(cov, self.eps)
             avg_squared_eig = (cov**2).sum(dim=0).mean()
+            if random.random() < 0.001 or __name__ == "__main__":
+                logging.info(f"Decorrelate: avg_squared_eig = {avg_squared_eig}")

             # the odd-looking formula below was obtained empirically, to match
             # the self-product and cross-correlation statistics of dropout
-            rand_scale = ((self.dropout_rate / (1.0 - self.dropout_rate)) ** 0.5) / avg_squared_eig
+
+            rand_scale1 = ((self.dropout_max_rate / (1.0 - self.dropout_max_rate)) ** 0.5) / avg_squared_eig
+            rand_scale2 = ((self.dropout_rate / (1.0 - self.dropout_rate)) ** 0.5)
+            rand_scale = torch.minimum(rand_scale1, torch.tensor(rand_scale2, device=x.device))
+

             # by multiplying by `cov`, then randomizing the sign of elements, then
             # multiplying by `cov` again, we are generating something that has
@@ -900,12 +912,13 @@ def _test_gauss_proj_drop():
         m2.eval()

 def _test_decorrelate():
+    logging.getLogger().setLevel(logging.INFO)
     x = torch.randn(30000, 384)

     for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
         m1 = torch.nn.Dropout(dropout_rate)
-        m2 = Decorrelate(apply_prob=1.0, rand_scale=dropout_rate)
+        m2 = Decorrelate(apply_prob=1.0, dropout_rate=dropout_rate)
         for mode in ['train', 'eval']:
             y1 = m1(x)
             y2 = m2(x)
@@ -921,9 +934,8 @@

 if __name__ == "__main__":
     _test_decorrelate()
-    if False:
-        _test_gauss_proj_drop()
-        _test_activation_balancer_sign()
-        _test_activation_balancer_magnitude()
-        _test_basic_norm()
-        _test_double_swish_deriv()
+    _test_gauss_proj_drop()
+    _test_activation_balancer_sign()
+    _test_activation_balancer_magnitude()
+    _test_basic_norm()
+    _test_double_swish_deriv()

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
index baf1441c3..142fa34d8 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
@@ -199,8 +199,7 @@ class ConformerEncoderLayer(nn.Module):
         )
         self.dropout = torch.nn.Dropout(dropout)

-        self.decorrelate = Decorrelate(apply_prob=0.25,
-                                       dropout_rate=0.01)
+        self.decorrelate = Decorrelate(apply_prob=0.25)

     def forward(
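
For reference, a minimal standalone sketch of the two-scale logic this patch
introduces. The function name compute_rand_scale and the example
avg_squared_eig values are invented for illustration and are not part of the
patch; in the real module, avg_squared_eig is computed from the normalized
feature covariance.

    import torch

    def compute_rand_scale(avg_squared_eig: torch.Tensor,
                           dropout_rate: float = 0.01,
                           dropout_max_rate: float = 0.1) -> torch.Tensor:
        # Scale matching the statistics of dropout at `dropout_max_rate`,
        # divided by avg_squared_eig.  avg_squared_eig is close to 1 for
        # uncorrelated features and grows with correlation, so this term
        # caps how aggressive the randomization can become.
        rand_scale1 = ((dropout_max_rate / (1.0 - dropout_max_rate)) ** 0.5) / avg_squared_eig
        # Scale matching the statistics of dropout at `dropout_rate`,
        # which applies when the features are uncorrelated.
        rand_scale2 = (dropout_rate / (1.0 - dropout_rate)) ** 0.5
        return torch.minimum(
            rand_scale1,
            torch.tensor(rand_scale2, device=avg_squared_eig.device),
        )

    # Uncorrelated features: the dropout_rate term is smaller and wins.
    print(compute_rand_scale(torch.tensor(1.0)))   # ~0.1005 = sqrt(0.01 / 0.99)
    # Strongly correlated features: the dropout_max_rate cap takes over.
    print(compute_rand_scale(torch.tensor(10.0)))  # ~0.0333 = sqrt(0.1 / 0.9) / 10

Taking the elementwise minimum keeps the default behaviour governed by
dropout_rate while bounding the effective dropout rate by dropout_max_rate
when the features are strongly correlated.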