Add ProjDrop for axis-independent dropout

Daniel Povey 2022-06-05 22:59:10 +08:00
parent 8a3068ead8
commit 136ffb0597
2 changed files with 53 additions and 2 deletions

View File

@@ -639,6 +639,45 @@ class ScaledEmbedding(nn.Module):
        return s.format(**self.__dict__)


class ProjDrop(torch.nn.Module):
    """
    This has an effect similar to torch.nn.Dropout, but does not privilege the
    on-axis directions: instead of zeroing individual channels, it zeroes out
    the projection of the input onto a randomly chosen subspace.

    dropout_rate: the dropout probability (it actually determines the number of
        zeroed-out directions, num_dropped = int(dropout_rate * num_channels)).
    channel_dim: the axis corresponding to the channel, e.g. -1, 0, 1, 2.
    """
    def __init__(self,
                 dropout_rate: float = 0.1,
                 channel_dim: int = -1):
        super(ProjDrop, self).__init__()
        self.dropout_rate = dropout_rate
        self.channel_dim = channel_dim

    def forward(self, x: Tensor) -> Tensor:
        if not self.training:
            # At test time, scale by (1 - dropout_rate) to match the expected
            # effect of dropping that fraction of directions during training.
            return x * (1.0 - self.dropout_rate)
        else:
            x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
            num_channels = x.shape[-1]
            num_dropped = int(self.dropout_rate * num_channels)
            r = torch.randn(num_dropped, num_channels,
                            device=x.device, dtype=x.dtype)
            rr = torch.matmul(r, r.t())  # (num_dropped, num_dropped)
            # Add to the diagonal to ensure rr is invertible.
            rr += 0.01 * torch.eye(num_dropped,
                                   device=x.device, dtype=x.dtype)
            rr_inv = rr.cholesky().cholesky_inverse()
            # r.t() @ rr_inv @ r is very nearly an orthogonal projection onto
            # the row space of r: num_dropped eigenvalues close to 1, the rest 0.
            xr = torch.matmul(x, r.t())  # (..., num_dropped)
            rr_inv_r = torch.matmul(rr_inv, r)  # (num_dropped, num_channels)
            xrr = torch.matmul(xr, rr_inv_r)  # (..., num_channels)
            return (x - xrr).transpose(self.channel_dim, -1)


def _test_activation_balancer_sign():
    probs = torch.arange(0, 1, 0.01)
    N = 1000
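
As a quick aside on the eigenvalue comment inside forward() above, the following small check (not part of the commit; the variable names simply mirror the ones above) confirms that, with the diagonal regularization, r.t() @ rr_inv @ r behaves as an orthogonal projection onto the row space of r, so x - xrr removes exactly that num_dropped-dimensional subspace:

import torch

num_channels, num_dropped = 300, 30
r = torch.randn(num_dropped, num_channels)
rr = torch.matmul(r, r.t()) + 0.01 * torch.eye(num_dropped)
rr_inv = torch.cholesky_inverse(torch.linalg.cholesky(rr))
proj = torch.matmul(r.t(), torch.matmul(rr_inv, r))  # (num_channels, num_channels)
eigs = torch.linalg.eigvalsh(proj)                   # ascending order
print(eigs[-num_dropped:].min())   # close to 1.0: the directions that get dropped
print(eigs[:-num_dropped].max())   # close to 0.0: the directions that pass through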
@@ -712,7 +751,18 @@ def _test_double_swish_deriv():
    torch.autograd.gradcheck(m, x)


def _test_proj_drop():
    x = torch.randn(3000, 300)
    m = ProjDrop(0.1)
    y = m(x)  # the module is in training mode, so a random subspace is dropped
    xmag = (x * x).mean().sqrt()
    ymag = (y * y).mean().sqrt()
    print(f"xmag = {xmag}, ymag = {ymag}")
    # Dropping 10% of the directions removes about 10% of the energy, so the
    # RMS magnitude shrinks by roughly sqrt(0.9).
    assert abs((ymag / xmag) - 0.9 ** 0.5) < 0.01


if __name__ == "__main__":
    _test_proj_drop()
    _test_activation_balancer_sign()
    _test_activation_balancer_magnitude()
    _test_basic_norm()
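
To make the "does not privilege the on-axis directions" point concrete, here is a small illustration (not part of the commit; it assumes scaling.py is importable): because the dropped subspace is drawn isotropically, rotating the input by a random orthogonal matrix leaves the expected fraction of removed energy unchanged, which is not true of nn.Dropout's per-coordinate zeroing.

import torch
from scaling import ProjDrop  # assumes scaling.py is on the Python path

torch.manual_seed(0)
x = torch.randn(3000, 300)
q, _ = torch.linalg.qr(torch.randn(300, 300))  # random orthogonal rotation

m = ProjDrop(0.1)  # nn.Module defaults to training mode

def energy(t: torch.Tensor) -> torch.Tensor:
    return (t * t).mean()

# Both fractions should come out near 0.1: the same subspace-dropping
# behaviour whether or not the input has been rotated off-axis.
print(1.0 - energy(m(x)) / energy(x))
print(1.0 - energy(m(x @ q)) / energy(x @ q))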

View File

@@ -29,6 +29,7 @@ from scaling import (
    ScaledConv1d,
    ScaledConv2d,
    ScaledLinear,
    ProjDrop,
)
from torch import Tensor, nn
@@ -196,7 +197,7 @@ class ConformerEncoderLayer(nn.Module):
            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
        )
        self.dropout = nn.Dropout(dropout)
        self.dropout = ProjDrop(dropout)

    def forward(
        self,
@@ -368,7 +369,7 @@ class RelPositionalEncoding(torch.nn.Module):
        """Construct a PositionalEncoding object."""
        super(RelPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.dropout = ProjDrop(dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
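
Finally, a minimal usage sketch of the swap being made in this file (not part of the commit; TinyFeedForward and its sizes are invented for illustration, and scaling.py is assumed importable): ProjDrop goes exactly where nn.Dropout went, but it takes the rate as its first positional argument (dropout_rate) rather than nn.Dropout's p= keyword.

import torch
from torch import nn
from scaling import ProjDrop  # assumes scaling.py is on the Python path

class TinyFeedForward(nn.Module):
    """Illustrative stand-in for the dropout-carrying blocks above."""
    def __init__(self, d_model: int = 256, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, 4 * d_model)
        self.linear2 = nn.Linear(4 * d_model, d_model)
        # Drop-in replacement for nn.Dropout(dropout); channel_dim=-1 is the default.
        self.dropout = ProjDrop(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear2(self.dropout(torch.relu(self.linear1(x))))

m = TinyFeedForward()
y = m(torch.randn(10, 20, 256))  # (batch, time, channel)
print(y.shape)                   # torch.Size([10, 20, 256])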