from local

This commit is contained in:
dohe0342 2023-03-13 13:19:13 +09:00
parent 01aad91538
commit 392f91b2f3
2 changed files with 234 additions and 0 deletions

Binary file not shown.

View File

@ -1595,6 +1595,240 @@ class Conv2dSubsampling(nn.Module):
return x return x
ACTIVATIONS = {
'relu': torch.nn.ReLU,
'leaky_relu': torch.nn.LeakyReLU,
'gelu': torch.nn.GELU,
}
class FCLayer(torch.nn.Module):
def __init__(self, in_channels, out_channels, activation='gelu'):
super(FCLayer, self).__init__()
in_channels = int(in_channels)
out_channels = int(out_channels)
self.in_channels = in_channels
self.out_channels = out_channels
self.linear = torch.nn.Linear(in_channels, out_channels)
self.ins = torch.nn.InstanceNorm1d(out_channels, affine=True)
self.act = ACTIVATIONS[activation]()
def forward(self, x):
out = self.linear(x)
out = out.transpose(1,2)
out = self.ins(out)
out = self.act(out)
out = out.transpose(1,2)
return out
class ConvLayer2(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, groups=1):
super(ConvLayer2, self).__init__()
reflection_padding = kernel_size // 2
self.reflection_pad = torch.nn.ReflectionPad1d(reflection_padding)
self.conv1d = ScaledConv1d(
in_channels, out_channels, kernel_size, stride, groups=groups)
def forward(self, x, pad=False):
if pad:
out = self.reflection_pad(x)
else:
out = x
out = self.conv1d(out)
return out
class ResidualBlock2(torch.nn.Module):
"""ResidualBlock
introduced in: https://arxiv.org/abs/1512.03385
recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
"""
def __init__(self, channels, activation='gelu'):
super(ResidualBlock2, self).__init__()
self.conv1 = nn.Conv1d(channels, channels, kernel_size=2)
self.in1 = nn.InstanceNorm1d(channels, affine=True)
self.conv2 = nn.Conv1d(channels, channels, kernel_size=2)
self.in2 = nn.InstanceNorm1d(channels, affine=True)
self.reflection_pad = torch.nn.ReflectionPad1d(1)
self.act = ACTIVATIONS[activation]()
def forward(self, x):
residual = x
out = self.act(self.in1(self.conv1(self.reflection_pad(x))))
out = self.in2(self.conv2(out))
out = out + residual
return out
class ViewMaker1(BaseFairseqModel):
'''Viewmaker network that stochastically maps a multichannel 2D input to an output of the same size.'''
def __init__(self, num_channels=512, distortion_budget=0.02, activation='gelu',
clamp=True, frequency_domain=False, downsample_to=False, num_res_blocks=0, num_noise=0):
'''Initialize the Viewmaker network.
Args:
num_channels: Number of channels in the input (e.g. 1 for speech, 3 for images)
Input will have shape [batch_size, num_channels, height, width]
distortion_budget: Distortion budget of the viewmaker (epsilon, in the paper).
Controls how strong the perturbations can be.
activation: The activation function used in the network ('relu' and 'leaky_relu' currently supported)
clamp: Whether to clamp the outputs to [0, 1] (useful to ensure output is, e.g., a valid image)
frequency_domain: Whether to apply perturbation (and distortion budget) in the frequency domain.
This is useful for shifting the inductive bias of the viewmaker towards more global / textural views.
downsample_to: Downsamples the image, applies viewmaker, then upsamples. Possibly useful for
higher-resolution inputs, but not evaluaed in the paper.
num_res_blocks: Number of residual blocks to use in the network.
'''
super().__init__()
self.num_channels = num_channels
self.num_res_blocks = num_res_blocks
self.activation = activation
self.clamp = clamp
self.frequency_domain = frequency_domain
self.downsample_to = downsample_to
self.distortion_budget = distortion_budget
self.num_noise = num_noise
self.act = ACTIVATIONS[activation]()
# Initial convolution layers (+ 1 for noise filter)
self.conv1 = ConvLayer2(self.num_channels + self.num_noise, \
self.num_channels, kernel_size=2, stride=1)
self.in1 = torch.nn.InstanceNorm1d(self.num_channels, affine=True)
self.conv2 = ConvLayer2(self.num_channels, self.num_channels, kernel_size=2, stride=1)
self.in2 = torch.nn.InstanceNorm1d(self.num_channels, affine=True)
self.conv3 = ConvLayer2(self.num_channels, self.num_channels, kernel_size=2, stride=1)
self.in3 = torch.nn.InstanceNorm1d(self.num_channels, affine=True)
self.conv4 = ConvLayer2(self.num_channels, self.num_channels, kernel_size=2, stride=1)
self.in4 = torch.nn.InstanceNorm1d(self.num_channels, affine=True)
# Residual layers have +N for added random channels
if not self.num_noise:
self.res1 = ResidualBlock2(self.num_channels + 1)
self.res2 = ResidualBlock2(self.num_channels + 2)
self.res3 = ResidualBlock2(self.num_channels + 3)
self.res4 = ResidualBlock2(self.num_channels + 4)
self.res5 = ResidualBlock2(self.num_channels + 5)
self.conv5 = ConvLayer2(self.num_channels+self.num_res_blocks, \
self.num_channels, kernel_size=2, stride=1)
else:
self.res1 = ResidualBlock2(self.num_channels)
self.res2 = ResidualBlock2(self.num_channels)
self.res3 = ResidualBlock2(self.num_channels)
self.res4 = ResidualBlock2(self.num_channels)
self.res5 = ResidualBlock2(self.num_channels)
self.conv5 = ConvLayer2(self.num_channels, \
self.num_channels, kernel_size=2, stride=1)
self.ins5 = torch.nn.InstanceNorm1d(self.num_channels, affine=True)
self.conv6 = ConvLayer2(self.num_channels, self.num_channels, kernel_size=2, stride=1)
@staticmethod
def zero_init(m):
if isinstance(m, (nn.Linear, nn.Conv2d)):
# actual 0 has symmetry problems
init.normal_(m.weight.data, mean=0, std=1e-4)
# init.constant_(m.weight.data, 0)
init.constant_(m.bias.data, 0)
elif isinstance(m, nn.BatchNorm1d):
pass
def add_noise_channel(self, x, num=1, bound_multiplier=1):
# bound_multiplier is a scalar or a 1D tensor of length batch_size
batch_size = x.size(0)
filter_size = x.size(-1)
shp = (batch_size, num, filter_size)
bound_multiplier = torch.tensor(bound_multiplier, device=x.device)
noise = torch.rand(shp, device=x.device) * bound_multiplier.view(-1, 1, 1)
#if x.dtype == 'torch.cuda.float16':
# print('fuck'*100)
#noise.type(torch.cuda.float16)
noise = noise.half()
return torch.cat((x, noise), dim=1)
def basic_net(self, y, num_res_blocks=3, bound_multiplier=1):
if self.num_noise:
y = self.add_noise_channel(y, num=self.num_noise, bound_multiplier=bound_multiplier)
y = self.act(self.in1(self.conv1(y)))
y = self.act(self.in2(self.conv2(y, pad=True)))
y = self.act(self.in3(self.conv3(y)))
y = self.act(self.in4(self.conv4(y, pad=True)))
# Features that could be useful for other auxilary layers / losses.
# [batch_size, 128]
features = y.clone().mean([-1, -2])
for i, res in enumerate([self.res1, self.res2, self.res3, self.res4, self.res5]):
if i < num_res_blocks:
if not self.num_noise:
y = res(self.add_noise_channel(y, bound_multiplier=bound_multiplier))
else:
y = res(y)
y = self.act(self.ins5(self.conv5(y, pad=True)))
y = self.conv6(y)
return y, features
def get_delta(self, y_pixels, eps=1e-4):
'''Constrains the input perturbation by projecting it onto an L1 sphere'''
distortion_budget = self.distortion_budget
delta = torch.tanh(y_pixels) # Project to [-1, 1]
avg_magnitude = delta.abs().mean([1,2], keepdim=True)
max_magnitude = distortion_budget
delta = delta * max_magnitude / (avg_magnitude + eps)
return delta
def get_delta2(self, y_pixels, padding_mask, eps=1e-4):
'''Constrains the input perturbation by projecting it onto an L1 sphere'''
if padding_mask is not None:
padding_mask_ = torch.logical_not(padding_mask)
padding_mask_ = padding_mask_.long().unsqueeze(2)
y_pixels = y_pixels.transpose(1,2)
y_pixels *= padding_mask_
y_pixels = y_pixels.transpose(1,2)
distortion_budget = self.distortion_budget
delta = torch.tanh(y_pixels) # Project to [-1, 1]
avg_magnitude = delta.abs().mean([1,2], keepdim=True)
max_magnitude = distortion_budget
delta = delta * max_magnitude / (avg_magnitude + eps)
return delta
def forward(self, x, padding_mask):
x = x.transpose(1,2)
if self.downsample_to:
# Downsample.
x_orig = x
x = torch.nn.functional.interpolate(
x, size=(self.downsample_to, self.downsample_to), mode='bilinear')
y = x
if self.frequency_domain and 0:
# Input to viewmaker is in frequency domain, outputs frequency domain perturbation.
# Uses the Discrete Cosine Transform.
# shape still [batch_size, C, W, H]
y = dct.dct_2d(y)
y_pixels, features = self.basic_net(y, self.num_res_blocks, bound_multiplier=1)
#delta = self.get_delta(y_pixels.clone())
delta = self.get_delta2(y_pixels.clone(), padding_mask)
# Additive perturbation
#result = x + delta
result = y_pixels
delta = delta.transpose(1,2)
result = result.transpose(1,2)
return result, delta
if __name__ == "__main__": if __name__ == "__main__":
torch.set_num_threads(1) torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)