Randomize the projections

Daniel Povey 2022-06-06 16:05:18 +08:00
parent 6fdb356315
commit 31848dcd11
2 changed files with 18 additions and 7 deletions

View File

@@ -19,6 +19,7 @@ import collections
 from itertools import repeat
 from typing import Optional, Tuple
+import random
 import torch
 import torch.nn as nn
 from torch import Tensor
@@ -639,7 +640,7 @@ class ScaledEmbedding(nn.Module):
         return s.format(**self.__dict__)
 
 
-class FixedProjDrop(torch.nn.Module):
+class RandProjDrop(torch.nn.Module):
     """
     This has an effect similar to torch.nn.Dropout, but does not privilege the on-axis directions.
     The directions of dropout are fixed when the class is initialized, and are orthogonal.
@@ -651,7 +652,7 @@ class FixedProjDrop(torch.nn.Module):
                  num_channels: int,
                  dropout_rate: float = 0.1,
                  channel_dim: int = -1):
-        super(FixedProjDrop, self).__init__()
+        super(RandProjDrop, self).__init__()
         self.dropout_rate = dropout_rate
         self.channel_dim = channel_dim
@@ -665,6 +666,7 @@ class FixedProjDrop(torch.nn.Module):
             x = torch.nn.functional.dropout(x, self.dropout_rate,
                                             training=False)
         else:
+            self._randomize_U()
             x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
             x = torch.matmul(x, self.U)
@@ -674,7 +676,16 @@ class FixedProjDrop(torch.nn.Module):
             x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
         return x
 
+    def _randomize_U(self):
+        dim = random.randint(0, 1)
+        U = self.U
+        num_channels = U.shape[0]
+        # pick place to split U in two pieces.
+        r = random.randint(1, num_channels - 2)
+        U_part1 = U.narrow(dim, 0, r)
+        U_part2 = U.narrow(dim, r, num_channels-r)
+        U = torch.cat((U_part2, U_part1), dim=dim)
+        self.U[:] = U
 
 
 def _test_activation_balancer_sign():
     probs = torch.arange(0, 1, 0.01)
@@ -751,7 +762,7 @@ def _test_double_swish_deriv():
 def _test_proj_drop():
     x = torch.randn(30000, 300)
-    m = FixedProjDrop(300, 0.1)
+    m = RandProjDrop(300, 0.1)
     y = m(x)
     xmag = (x*x).mean()
     ymag = (y*y).mean()
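
A note on what the training-time branch now does: instead of keeping the orthogonal basis U fixed for the life of the module, _randomize_U reshuffles U (a cyclic rotation of its rows or columns) before every training-time forward, so the projection applied around dropout changes from step to step. Below is a minimal standalone sketch of that idea; it is not the code of this commit (RandProjDrop's __init__ and the tail of forward are outside the hunks above), so the random_orthogonal helper and the project / drop / project-back ordering are assumptions for illustration only.

import random
import torch

def random_orthogonal(num_channels: int) -> torch.Tensor:
    # Hypothetical stand-in for whatever __init__ does (not shown in the diff):
    # QR-decompose a random Gaussian matrix to get an orthogonal basis.
    q, _ = torch.linalg.qr(torch.randn(num_channels, num_channels))
    return q

def randomize_U(U: torch.Tensor) -> torch.Tensor:
    # Same trick as _randomize_U above: split U at a random point along a
    # random axis and swap the two pieces.  Permuting the rows or columns of
    # an orthogonal matrix leaves it orthogonal.
    dim = random.randint(0, 1)
    num_channels = U.shape[0]
    r = random.randint(1, num_channels - 2)
    return torch.cat((U.narrow(dim, r, num_channels - r),
                      U.narrow(dim, 0, r)), dim=dim)

num_channels, dropout_rate = 8, 0.1
U = random_orthogonal(num_channels)
x = torch.randn(4, num_channels)

U = randomize_U(U)                       # fresh basis rotation for this step
y = torch.matmul(x, U)                   # rotate into the dropout basis
y = torch.nn.functional.dropout(y, dropout_rate, training=True)
y = torch.matmul(y, U.t())               # rotate back (assumed inverse step)
print(torch.allclose(U @ U.t(), torch.eye(num_channels), atol=1e-5))  # True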

View File

@@ -29,7 +29,7 @@ from scaling import (
     ScaledConv1d,
     ScaledConv2d,
     ScaledLinear,
-    FixedProjDrop,
+    RandProjDrop,
 )
 
 from torch import Tensor, nn
@@ -197,7 +197,7 @@ class ConformerEncoderLayer(nn.Module):
             channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
         )
-        self.dropout = FixedProjDrop(d_model, dropout)
+        self.dropout = RandProjDrop(d_model, dropout)
 
     def forward(
         self,
@@ -369,7 +369,7 @@ class RelPositionalEncoding(torch.nn.Module):
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
-        self.dropout = FixedProjDrop(d_model, dropout_rate)
+        self.dropout = RandProjDrop(d_model, dropout_rate)
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
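
On the conformer.py side the commit is only the rename at the import and the two construction sites above. For reference, the drop-in usage pattern is roughly the toy layer below; the real ConformerEncoderLayer has many more submodules, and the snippet assumes this recipe's scaling.py is on the import path so that RandProjDrop resolves.

import torch
import torch.nn as nn

from scaling import RandProjDrop  # this recipe's scaling.py

class ToyLayer(nn.Module):
    # Toy stand-in for the pattern used in ConformerEncoderLayer: RandProjDrop
    # takes (num_channels, dropout_rate) rather than nn.Dropout's single p,
    # since it needs the channel count to build its square projection U.
    def __init__(self, d_model: int, dropout: float = 0.1):
        super().__init__()
        self.linear = nn.Linear(d_model, d_model)
        self.dropout = RandProjDrop(d_model, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.dropout(self.linear(x))

layer = ToyLayer(256)
y = layer(torch.randn(10, 4, 256))  # (time, batch, d_model); channels on the last dim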