commit 31848dcd11
parent 6fdb356315

    Randomize the projections
@@ -19,6 +19,7 @@ import collections
 from itertools import repeat
 from typing import Optional, Tuple
 
+import random
 import torch
 import torch.nn as nn
 from torch import Tensor
@@ -639,7 +640,7 @@ class ScaledEmbedding(nn.Module):
         return s.format(**self.__dict__)
 
 
-class FixedProjDrop(torch.nn.Module):
+class RandProjDrop(torch.nn.Module):
     """
     This has an effect similar to torch.nn.Dropout, but does not privilege the on-axis directions.
     The directions of dropout are fixed when the class is initialized, and are orthogonal.
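For context (not part of this commit): the class keeps an orthogonal matrix U and applies dropout in the basis U defines rather than along the raw channel axes. The initializer is outside these hunks; the following is a minimal sketch of one conventional way such a basis could be built, via QR decomposition of a Gaussian matrix. make_orthogonal is a hypothetical helper, not code from this repo.

    # Hypothetical sketch: build a square orthogonal matrix like the
    # buffer U that RandProjDrop rotates with. QR of a random Gaussian
    # matrix yields an orthogonal Q.
    import torch

    def make_orthogonal(num_channels: int) -> torch.Tensor:
        a = torch.randn(num_channels, num_channels)
        q, _ = torch.linalg.qr(a)  # q satisfies q @ q.t() == I
        return q

    U = make_orthogonal(300)
    print(torch.allclose(U @ U.t(), torch.eye(300), atol=1e-4))  # True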
@@ -651,7 +652,7 @@ class FixedProjDrop(torch.nn.Module):
                  num_channels: int,
                  dropout_rate: float = 0.1,
                  channel_dim: int = -1):
-        super(FixedProjDrop, self).__init__()
+        super(RandProjDrop, self).__init__()
         self.dropout_rate = dropout_rate
         self.channel_dim = channel_dim
 
@@ -665,6 +666,7 @@ class FixedProjDrop(torch.nn.Module):
             x = torch.nn.functional.dropout(x, self.dropout_rate,
                                             training=False)
         else:
+            self._randomize_U()
             x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
 
             x = torch.matmul(x, self.U)
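The lines between the two matmuls are outside this diff, so the following is only a sketch of the rotate, drop, rotate-back pattern the visible fragments suggest; proj_dropout is a hypothetical free function, not the module's actual forward.

    # Hypothetical sketch of the training-mode path implied by the
    # fragments: rotate x into U's basis, apply ordinary elementwise
    # dropout there, then rotate back.
    import torch

    def proj_dropout(x: torch.Tensor, U: torch.Tensor, p: float) -> torch.Tensor:
        x = torch.matmul(x, U)  # into the U basis
        x = torch.nn.functional.dropout(x, p, training=True)  # drop coordinates there
        x = torch.matmul(x, U.t())  # back to the original basis
        return x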
@@ -674,7 +676,16 @@ class FixedProjDrop(torch.nn.Module):
             x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
         return x
 
+    def _randomize_U(self):
+        dim = random.randint(0, 1)
+        U = self.U
+        num_channels = U.shape[0]
+        # pick place to split U in two pieces.
+        r = random.randint(1, num_channels - 2)
+        U_part1 = U.narrow(dim, 0, r)
+        U_part2 = U.narrow(dim, r, num_channels - r)
+        U = torch.cat((U_part2, U_part1), dim=dim)
+        self.U[:] = U
 
 def _test_activation_balancer_sign():
     probs = torch.arange(0, 1, 0.01)
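Why the split-and-swap in _randomize_U is safe: cyclically rotating the rows (or columns) of an orthogonal matrix is just a permutation, so the result stays orthogonal and the buffer can be re-randomized in place without re-orthogonalizing. A quick standalone check:

    # Rotating the rows of an orthogonal U gives P @ U for a permutation
    # matrix P, and (P @ U) @ (P @ U).t() == P @ P.t() == I.
    import torch

    U, _ = torch.linalg.qr(torch.randn(8, 8))
    r = 3
    U_rot = torch.cat((U.narrow(0, r, 8 - r), U.narrow(0, 0, r)), dim=0)
    print(torch.allclose(U_rot @ U_rot.t(), torch.eye(8), atol=1e-5))  # True

Given the x @ U orientation seen in the forward fragment, only the row rotation (dim=0) actually changes the projection directions; the column rotation relabels which coordinate gets dropped, which is distributionally a no-op since dropout treats coordinates identically.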
@@ -751,7 +762,7 @@ def _test_double_swish_deriv():
 
 def _test_proj_drop():
     x = torch.randn(30000, 300)
-    m = FixedProjDrop(300, 0.1)
+    m = RandProjDrop(300, 0.1)
     y = m(x)
     xmag = (x*x).mean()
     ymag = (y*y).mean()
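The assertion this test makes is outside the diff, but with PyTorch's inverted dropout (survivors scaled by 1/(1-p)) the mean square of the output should exceed the input's by roughly a factor of 1/(1-p), and an orthogonal rotation leaves the mean square unchanged. A standalone illustration with plain dropout:

    # E[y^2] = (1-p) * (x/(1-p))^2 = x^2/(1-p), so for p = 0.1 the
    # ratio ymag/xmag should come out near 1/0.9 = 1.111...
    import torch

    x = torch.randn(30000, 300)
    y = torch.nn.functional.dropout(x, 0.1, training=True)
    print(((y * y).mean() / (x * x).mean()).item())  # ~1.11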

@@ -29,7 +29,7 @@ from scaling import (
     ScaledConv1d,
     ScaledConv2d,
     ScaledLinear,
-    FixedProjDrop,
+    RandProjDrop,
 )
 from torch import Tensor, nn
 
||||||
@ -197,7 +197,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
|
channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
|
||||||
)
|
)
|
||||||
|
|
||||||
self.dropout = FixedProjDrop(d_model, dropout)
|
self.dropout = RandProjDrop(d_model, dropout)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
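Usage pattern implied by the call sites: unlike nn.Dropout, RandProjDrop also needs the channel count, so each swap passes d_model alongside the rate. A hypothetical minimal module showing the substitution (TinyLayer is illustrative, not from the repo):

    import torch.nn as nn
    from scaling import RandProjDrop  # same import the diff adds above

    class TinyLayer(nn.Module):
        def __init__(self, d_model: int = 256, dropout: float = 0.1):
            super().__init__()
            self.linear = nn.Linear(d_model, d_model)
            # was: nn.Dropout(dropout); RandProjDrop also takes the size
            self.dropout = RandProjDrop(d_model, dropout)

        def forward(self, x):
            return self.dropout(self.linear(x))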
@@ -369,7 +369,7 @@ class RelPositionalEncoding(torch.nn.Module):
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
-        self.dropout = FixedProjDrop(d_model, dropout_rate)
+        self.dropout = RandProjDrop(d_model, dropout_rate)
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
 