mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
from local
This commit is contained in:
parent
01aad91538
commit
392f91b2f3
BIN
egs/librispeech/ASR/.prepare.sh.swp
Normal file
BIN
egs/librispeech/ASR/.prepare.sh.swp
Normal file
Binary file not shown.
@ -1595,6 +1595,240 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
ACTIVATIONS = {
|
||||||
|
'relu': torch.nn.ReLU,
|
||||||
|
'leaky_relu': torch.nn.LeakyReLU,
|
||||||
|
'gelu': torch.nn.GELU,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class FCLayer(torch.nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, activation='gelu'):
|
||||||
|
super(FCLayer, self).__init__()
|
||||||
|
in_channels = int(in_channels)
|
||||||
|
out_channels = int(out_channels)
|
||||||
|
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
|
||||||
|
self.linear = torch.nn.Linear(in_channels, out_channels)
|
||||||
|
self.ins = torch.nn.InstanceNorm1d(out_channels, affine=True)
|
||||||
|
self.act = ACTIVATIONS[activation]()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = self.linear(x)
|
||||||
|
out = out.transpose(1,2)
|
||||||
|
out = self.ins(out)
|
||||||
|
out = self.act(out)
|
||||||
|
out = out.transpose(1,2)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class ConvLayer2(torch.nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1, groups=1):
|
||||||
|
super(ConvLayer2, self).__init__()
|
||||||
|
reflection_padding = kernel_size // 2
|
||||||
|
self.reflection_pad = torch.nn.ReflectionPad1d(reflection_padding)
|
||||||
|
self.conv1d = ScaledConv1d(
|
||||||
|
in_channels, out_channels, kernel_size, stride, groups=groups)
|
||||||
|
|
||||||
|
def forward(self, x, pad=False):
|
||||||
|
if pad:
|
||||||
|
out = self.reflection_pad(x)
|
||||||
|
else:
|
||||||
|
out = x
|
||||||
|
out = self.conv1d(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock2(torch.nn.Module):
|
||||||
|
"""ResidualBlock
|
||||||
|
introduced in: https://arxiv.org/abs/1512.03385
|
||||||
|
recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, channels, activation='gelu'):
|
||||||
|
super(ResidualBlock2, self).__init__()
|
||||||
|
self.conv1 = nn.Conv1d(channels, channels, kernel_size=2)
|
||||||
|
self.in1 = nn.InstanceNorm1d(channels, affine=True)
|
||||||
|
self.conv2 = nn.Conv1d(channels, channels, kernel_size=2)
|
||||||
|
self.in2 = nn.InstanceNorm1d(channels, affine=True)
|
||||||
|
self.reflection_pad = torch.nn.ReflectionPad1d(1)
|
||||||
|
self.act = ACTIVATIONS[activation]()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
residual = x
|
||||||
|
out = self.act(self.in1(self.conv1(self.reflection_pad(x))))
|
||||||
|
out = self.in2(self.conv2(out))
|
||||||
|
out = out + residual
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class ViewMaker1(BaseFairseqModel):
|
||||||
|
'''Viewmaker network that stochastically maps a multichannel 2D input to an output of the same size.'''
|
||||||
|
def __init__(self, num_channels=512, distortion_budget=0.02, activation='gelu',
|
||||||
|
clamp=True, frequency_domain=False, downsample_to=False, num_res_blocks=0, num_noise=0):
|
||||||
|
'''Initialize the Viewmaker network.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_channels: Number of channels in the input (e.g. 1 for speech, 3 for images)
|
||||||
|
Input will have shape [batch_size, num_channels, height, width]
|
||||||
|
distortion_budget: Distortion budget of the viewmaker (epsilon, in the paper).
|
||||||
|
Controls how strong the perturbations can be.
|
||||||
|
activation: The activation function used in the network ('relu' and 'leaky_relu' currently supported)
|
||||||
|
clamp: Whether to clamp the outputs to [0, 1] (useful to ensure output is, e.g., a valid image)
|
||||||
|
frequency_domain: Whether to apply perturbation (and distortion budget) in the frequency domain.
|
||||||
|
This is useful for shifting the inductive bias of the viewmaker towards more global / textural views.
|
||||||
|
downsample_to: Downsamples the image, applies viewmaker, then upsamples. Possibly useful for
|
||||||
|
higher-resolution inputs, but not evaluaed in the paper.
|
||||||
|
num_res_blocks: Number of residual blocks to use in the network.
|
||||||
|
'''
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.num_channels = num_channels
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.activation = activation
|
||||||
|
self.clamp = clamp
|
||||||
|
self.frequency_domain = frequency_domain
|
||||||
|
self.downsample_to = downsample_to
|
||||||
|
self.distortion_budget = distortion_budget
|
||||||
|
self.num_noise = num_noise
|
||||||
|
self.act = ACTIVATIONS[activation]()
|
||||||
|
|
||||||
|
# Initial convolution layers (+ 1 for noise filter)
|
||||||
|
self.conv1 = ConvLayer2(self.num_channels + self.num_noise, \
|
||||||
|
self.num_channels, kernel_size=2, stride=1)
|
||||||
|
self.in1 = torch.nn.InstanceNorm1d(self.num_channels, affine=True)
|
||||||
|
self.conv2 = ConvLayer2(self.num_channels, self.num_channels, kernel_size=2, stride=1)
|
||||||
|
self.in2 = torch.nn.InstanceNorm1d(self.num_channels, affine=True)
|
||||||
|
self.conv3 = ConvLayer2(self.num_channels, self.num_channels, kernel_size=2, stride=1)
|
||||||
|
self.in3 = torch.nn.InstanceNorm1d(self.num_channels, affine=True)
|
||||||
|
self.conv4 = ConvLayer2(self.num_channels, self.num_channels, kernel_size=2, stride=1)
|
||||||
|
self.in4 = torch.nn.InstanceNorm1d(self.num_channels, affine=True)
|
||||||
|
|
||||||
|
# Residual layers have +N for added random channels
|
||||||
|
if not self.num_noise:
|
||||||
|
self.res1 = ResidualBlock2(self.num_channels + 1)
|
||||||
|
self.res2 = ResidualBlock2(self.num_channels + 2)
|
||||||
|
self.res3 = ResidualBlock2(self.num_channels + 3)
|
||||||
|
self.res4 = ResidualBlock2(self.num_channels + 4)
|
||||||
|
self.res5 = ResidualBlock2(self.num_channels + 5)
|
||||||
|
self.conv5 = ConvLayer2(self.num_channels+self.num_res_blocks, \
|
||||||
|
self.num_channels, kernel_size=2, stride=1)
|
||||||
|
else:
|
||||||
|
self.res1 = ResidualBlock2(self.num_channels)
|
||||||
|
self.res2 = ResidualBlock2(self.num_channels)
|
||||||
|
self.res3 = ResidualBlock2(self.num_channels)
|
||||||
|
self.res4 = ResidualBlock2(self.num_channels)
|
||||||
|
self.res5 = ResidualBlock2(self.num_channels)
|
||||||
|
self.conv5 = ConvLayer2(self.num_channels, \
|
||||||
|
self.num_channels, kernel_size=2, stride=1)
|
||||||
|
|
||||||
|
self.ins5 = torch.nn.InstanceNorm1d(self.num_channels, affine=True)
|
||||||
|
self.conv6 = ConvLayer2(self.num_channels, self.num_channels, kernel_size=2, stride=1)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def zero_init(m):
|
||||||
|
if isinstance(m, (nn.Linear, nn.Conv2d)):
|
||||||
|
# actual 0 has symmetry problems
|
||||||
|
init.normal_(m.weight.data, mean=0, std=1e-4)
|
||||||
|
# init.constant_(m.weight.data, 0)
|
||||||
|
init.constant_(m.bias.data, 0)
|
||||||
|
elif isinstance(m, nn.BatchNorm1d):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def add_noise_channel(self, x, num=1, bound_multiplier=1):
|
||||||
|
# bound_multiplier is a scalar or a 1D tensor of length batch_size
|
||||||
|
batch_size = x.size(0)
|
||||||
|
filter_size = x.size(-1)
|
||||||
|
shp = (batch_size, num, filter_size)
|
||||||
|
bound_multiplier = torch.tensor(bound_multiplier, device=x.device)
|
||||||
|
noise = torch.rand(shp, device=x.device) * bound_multiplier.view(-1, 1, 1)
|
||||||
|
#if x.dtype == 'torch.cuda.float16':
|
||||||
|
# print('fuck'*100)
|
||||||
|
#noise.type(torch.cuda.float16)
|
||||||
|
noise = noise.half()
|
||||||
|
return torch.cat((x, noise), dim=1)
|
||||||
|
|
||||||
|
def basic_net(self, y, num_res_blocks=3, bound_multiplier=1):
|
||||||
|
if self.num_noise:
|
||||||
|
y = self.add_noise_channel(y, num=self.num_noise, bound_multiplier=bound_multiplier)
|
||||||
|
y = self.act(self.in1(self.conv1(y)))
|
||||||
|
y = self.act(self.in2(self.conv2(y, pad=True)))
|
||||||
|
y = self.act(self.in3(self.conv3(y)))
|
||||||
|
y = self.act(self.in4(self.conv4(y, pad=True)))
|
||||||
|
|
||||||
|
# Features that could be useful for other auxilary layers / losses.
|
||||||
|
# [batch_size, 128]
|
||||||
|
features = y.clone().mean([-1, -2])
|
||||||
|
|
||||||
|
for i, res in enumerate([self.res1, self.res2, self.res3, self.res4, self.res5]):
|
||||||
|
if i < num_res_blocks:
|
||||||
|
if not self.num_noise:
|
||||||
|
y = res(self.add_noise_channel(y, bound_multiplier=bound_multiplier))
|
||||||
|
else:
|
||||||
|
y = res(y)
|
||||||
|
|
||||||
|
y = self.act(self.ins5(self.conv5(y, pad=True)))
|
||||||
|
y = self.conv6(y)
|
||||||
|
|
||||||
|
return y, features
|
||||||
|
|
||||||
|
def get_delta(self, y_pixels, eps=1e-4):
|
||||||
|
'''Constrains the input perturbation by projecting it onto an L1 sphere'''
|
||||||
|
distortion_budget = self.distortion_budget
|
||||||
|
delta = torch.tanh(y_pixels) # Project to [-1, 1]
|
||||||
|
avg_magnitude = delta.abs().mean([1,2], keepdim=True)
|
||||||
|
max_magnitude = distortion_budget
|
||||||
|
delta = delta * max_magnitude / (avg_magnitude + eps)
|
||||||
|
return delta
|
||||||
|
|
||||||
|
def get_delta2(self, y_pixels, padding_mask, eps=1e-4):
|
||||||
|
'''Constrains the input perturbation by projecting it onto an L1 sphere'''
|
||||||
|
if padding_mask is not None:
|
||||||
|
padding_mask_ = torch.logical_not(padding_mask)
|
||||||
|
padding_mask_ = padding_mask_.long().unsqueeze(2)
|
||||||
|
y_pixels = y_pixels.transpose(1,2)
|
||||||
|
y_pixels *= padding_mask_
|
||||||
|
y_pixels = y_pixels.transpose(1,2)
|
||||||
|
|
||||||
|
distortion_budget = self.distortion_budget
|
||||||
|
delta = torch.tanh(y_pixels) # Project to [-1, 1]
|
||||||
|
avg_magnitude = delta.abs().mean([1,2], keepdim=True)
|
||||||
|
max_magnitude = distortion_budget
|
||||||
|
delta = delta * max_magnitude / (avg_magnitude + eps)
|
||||||
|
return delta
|
||||||
|
|
||||||
|
def forward(self, x, padding_mask):
|
||||||
|
x = x.transpose(1,2)
|
||||||
|
if self.downsample_to:
|
||||||
|
# Downsample.
|
||||||
|
x_orig = x
|
||||||
|
x = torch.nn.functional.interpolate(
|
||||||
|
x, size=(self.downsample_to, self.downsample_to), mode='bilinear')
|
||||||
|
y = x
|
||||||
|
|
||||||
|
if self.frequency_domain and 0:
|
||||||
|
# Input to viewmaker is in frequency domain, outputs frequency domain perturbation.
|
||||||
|
# Uses the Discrete Cosine Transform.
|
||||||
|
# shape still [batch_size, C, W, H]
|
||||||
|
y = dct.dct_2d(y)
|
||||||
|
|
||||||
|
y_pixels, features = self.basic_net(y, self.num_res_blocks, bound_multiplier=1)
|
||||||
|
#delta = self.get_delta(y_pixels.clone())
|
||||||
|
delta = self.get_delta2(y_pixels.clone(), padding_mask)
|
||||||
|
|
||||||
|
# Additive perturbation
|
||||||
|
#result = x + delta
|
||||||
|
result = y_pixels
|
||||||
|
|
||||||
|
delta = delta.transpose(1,2)
|
||||||
|
result = result.transpose(1,2)
|
||||||
|
|
||||||
|
return result, delta
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user