Randomize the projections

This commit is contained in:
Daniel Povey 2022-06-06 16:05:18 +08:00
parent 6fdb356315
commit 31848dcd11
2 changed files with 18 additions and 7 deletions

View File

@ -19,6 +19,7 @@ import collections
from itertools import repeat
from typing import Optional, Tuple
import random
import torch
import torch.nn as nn
from torch import Tensor
@ -639,7 +640,7 @@ class ScaledEmbedding(nn.Module):
return s.format(**self.__dict__)
class FixedProjDrop(torch.nn.Module):
class RandProjDrop(torch.nn.Module):
"""
This has an effect similar to torch.nn.Dropout, but does not privilege the on-axis directions.
The directions of dropout are fixed when the class is initialized, and are orthogonal.
@ -651,7 +652,7 @@ class FixedProjDrop(torch.nn.Module):
num_channels: int,
dropout_rate: float = 0.1,
channel_dim: int = -1):
super(FixedProjDrop, self).__init__()
super(RandProjDrop, self).__init__()
self.dropout_rate = dropout_rate
self.channel_dim = channel_dim
@ -665,6 +666,7 @@ class FixedProjDrop(torch.nn.Module):
x = torch.nn.functional.dropout(x, self.dropout_rate,
training=False)
else:
self._randomize_U()
x = x.transpose(self.channel_dim, -1) # (..., num_channels)
x = torch.matmul(x, self.U)
@ -674,7 +676,16 @@ class FixedProjDrop(torch.nn.Module):
x = x.transpose(self.channel_dim, -1) # (..., num_channels)
return x
def _randomize_U(self):
dim = random.randint(0, 1)
U = self.U
num_channels = U.shape[0]
# pick place to split U in two pieces.
r = random.randint(1, num_channels - 2)
U_part1 = U.narrow(dim, 0, r)
U_part2 = U.narrow(dim, r, num_channels-r)
U = torch.cat((U_part2, U_part1), dim=dim)
self.U[:] = U
def _test_activation_balancer_sign():
probs = torch.arange(0, 1, 0.01)
@ -751,7 +762,7 @@ def _test_double_swish_deriv():
def _test_proj_drop():
x = torch.randn(30000, 300)
m = FixedProjDrop(300, 0.1)
m = RandProjDrop(300, 0.1)
y = m(x)
xmag = (x*x).mean()
ymag = (y*y).mean()

View File

@ -29,7 +29,7 @@ from scaling import (
ScaledConv1d,
ScaledConv2d,
ScaledLinear,
FixedProjDrop,
RandProjDrop,
)
from torch import Tensor, nn
@ -197,7 +197,7 @@ class ConformerEncoderLayer(nn.Module):
channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
)
self.dropout = FixedProjDrop(d_model, dropout)
self.dropout = RandProjDrop(d_model, dropout)
def forward(
self,
@ -369,7 +369,7 @@ class RelPositionalEncoding(torch.nn.Module):
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
self.d_model = d_model
self.dropout = FixedProjDrop(d_model, dropout_rate)
self.dropout = RandProjDrop(d_model, dropout_rate)
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))