diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_gpat/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless_gpat/conformer.py
index a591964cc..6937bf5ee 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless_gpat/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_gpat/conformer.py
@@ -1595,6 +1595,240 @@ class Conv2dSubsampling(nn.Module):
         return x
 
 
+ACTIVATIONS = {
+    'relu': torch.nn.ReLU,
+    'leaky_relu': torch.nn.LeakyReLU,
+    'gelu': torch.nn.GELU,
+}
+
+
+class FCLayer(torch.nn.Module):
+    def __init__(self, in_channels, out_channels, activation='gelu'):
+        super(FCLayer, self).__init__()
+        in_channels = int(in_channels)
+        out_channels = int(out_channels)
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        self.linear = torch.nn.Linear(in_channels, out_channels)
+        self.ins = torch.nn.InstanceNorm1d(out_channels, affine=True)
+        self.act = ACTIVATIONS[activation]()
+
+    def forward(self, x):
+        out = self.linear(x)
+        out = out.transpose(1, 2)
+        out = self.ins(out)
+        out = self.act(out)
+        out = out.transpose(1, 2)
+
+        return out
+
+
+class ConvLayer2(torch.nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, groups=1):
+        super(ConvLayer2, self).__init__()
+        reflection_padding = kernel_size // 2
+        self.reflection_pad = torch.nn.ReflectionPad1d(reflection_padding)
+        self.conv1d = ScaledConv1d(
+            in_channels, out_channels, kernel_size, stride, groups=groups)
+
+    def forward(self, x, pad=False):
+        if pad:
+            out = self.reflection_pad(x)
+        else:
+            out = x
+        out = self.conv1d(out)
+        return out
+
+
+class ResidualBlock2(torch.nn.Module):
+    """ResidualBlock
+    introduced in: https://arxiv.org/abs/1512.03385
+    recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
+    """
+
+    def __init__(self, channels, activation='gelu'):
+        super(ResidualBlock2, self).__init__()
+        self.conv1 = nn.Conv1d(channels, channels, kernel_size=2)
+        self.in1 = nn.InstanceNorm1d(channels, affine=True)
+        self.conv2 = nn.Conv1d(channels, channels, kernel_size=2)
+        self.in2 = nn.InstanceNorm1d(channels, affine=True)
+        self.reflection_pad = torch.nn.ReflectionPad1d(1)
+        self.act = ACTIVATIONS[activation]()
+
+    def forward(self, x):
+        residual = x
+        out = self.act(self.in1(self.conv1(self.reflection_pad(x))))
+        out = self.in2(self.conv2(out))
+        out = out + residual
+        return out
+
+
+class ViewMaker1(nn.Module):
+    '''Viewmaker network that stochastically maps a multichannel 1D input
+    (a sequence of encoder features) to an output of the same shape.'''
+
+    def __init__(self, num_channels=512, distortion_budget=0.02, activation='gelu',
+                 clamp=True, frequency_domain=False, downsample_to=False,
+                 num_res_blocks=0, num_noise=0):
+        '''Initialize the Viewmaker network.
+
+        Args:
+            num_channels: Number of feature channels in the input.
+                Input will have shape [batch_size, seq_len, num_channels].
+            distortion_budget: Distortion budget of the viewmaker (epsilon in
+                the paper). Controls how strong the perturbations can be.
+            activation: The activation function used in the network
+                ('relu', 'leaky_relu', and 'gelu' are currently supported).
+            clamp: Whether to clamp the outputs to [0, 1] (useful to ensure
+                the output is, e.g., a valid image).
+            frequency_domain: Whether to apply the perturbation (and distortion
+                budget) in the frequency domain. This is useful for shifting
+                the inductive bias of the viewmaker towards more global /
+                textural views.
+            downsample_to: Downsamples the input, applies the viewmaker, then
+                upsamples. Possibly useful for higher-resolution inputs, but
+                not evaluated in the paper.
+            num_res_blocks: Number of residual blocks to use in the network.
+        '''
+        super().__init__()
+
+        self.num_channels = num_channels
+        self.num_res_blocks = num_res_blocks
+        self.activation = activation
+        self.clamp = clamp
+        self.frequency_domain = frequency_domain
+        self.downsample_to = downsample_to
+        self.distortion_budget = distortion_budget
+        self.num_noise = num_noise
+        self.act = ACTIVATIONS[activation]()
+
+        # Initial convolution layers (+ num_noise for the noise channels)
+        self.conv1 = ConvLayer2(self.num_channels + self.num_noise,
+                                self.num_channels, kernel_size=2, stride=1)
+        self.in1 = torch.nn.InstanceNorm1d(self.num_channels, affine=True)
+        self.conv2 = ConvLayer2(self.num_channels, self.num_channels, kernel_size=2, stride=1)
+        self.in2 = torch.nn.InstanceNorm1d(self.num_channels, affine=True)
+        self.conv3 = ConvLayer2(self.num_channels, self.num_channels, kernel_size=2, stride=1)
+        self.in3 = torch.nn.InstanceNorm1d(self.num_channels, affine=True)
+        self.conv4 = ConvLayer2(self.num_channels, self.num_channels, kernel_size=2, stride=1)
+        self.in4 = torch.nn.InstanceNorm1d(self.num_channels, affine=True)
+
+        # Residual layers have +N channels for the noise channel that is
+        # added before each block when num_noise == 0.
+        if not self.num_noise:
+            self.res1 = ResidualBlock2(self.num_channels + 1)
+            self.res2 = ResidualBlock2(self.num_channels + 2)
+            self.res3 = ResidualBlock2(self.num_channels + 3)
+            self.res4 = ResidualBlock2(self.num_channels + 4)
+            self.res5 = ResidualBlock2(self.num_channels + 5)
+            self.conv5 = ConvLayer2(self.num_channels + self.num_res_blocks,
+                                    self.num_channels, kernel_size=2, stride=1)
+        else:
+            self.res1 = ResidualBlock2(self.num_channels)
+            self.res2 = ResidualBlock2(self.num_channels)
+            self.res3 = ResidualBlock2(self.num_channels)
+            self.res4 = ResidualBlock2(self.num_channels)
+            self.res5 = ResidualBlock2(self.num_channels)
+            self.conv5 = ConvLayer2(self.num_channels,
+                                    self.num_channels, kernel_size=2, stride=1)
+
+        self.ins5 = torch.nn.InstanceNorm1d(self.num_channels, affine=True)
+        self.conv6 = ConvLayer2(self.num_channels, self.num_channels, kernel_size=2, stride=1)
+
+    @staticmethod
+    def zero_init(m):
+        if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d)):
+            # an actual 0 has symmetry problems
+            nn.init.normal_(m.weight.data, mean=0, std=1e-4)
+            nn.init.constant_(m.bias.data, 0)
+        elif isinstance(m, nn.BatchNorm1d):
+            pass
+
+    def add_noise_channel(self, x, num=1, bound_multiplier=1):
+        # bound_multiplier is a scalar or a 1D tensor of length batch_size
+        batch_size = x.size(0)
+        filter_size = x.size(-1)
+        shp = (batch_size, num, filter_size)
+        bound_multiplier = torch.tensor(bound_multiplier, device=x.device)
+        noise = torch.rand(shp, device=x.device) * bound_multiplier.view(-1, 1, 1)
+        # Match the input dtype (e.g. fp16 under mixed-precision training),
+        # otherwise torch.cat below fails on mixed dtypes.
+        noise = noise.to(x.dtype)
+        return torch.cat((x, noise), dim=1)
+
+    def basic_net(self, y, num_res_blocks=3, bound_multiplier=1):
+        if self.num_noise:
+            y = self.add_noise_channel(y, num=self.num_noise, bound_multiplier=bound_multiplier)
+        y = self.act(self.in1(self.conv1(y)))
+        y = self.act(self.in2(self.conv2(y, pad=True)))
+        y = self.act(self.in3(self.conv3(y)))
+        y = self.act(self.in4(self.conv4(y, pad=True)))
+
+        # Features that could be useful for other auxiliary layers / losses.
+        # Shape [batch_size] (averaged over the channel and time dimensions).
+        features = y.clone().mean([-1, -2])
+
+        for i, res in enumerate([self.res1, self.res2, self.res3, self.res4, self.res5]):
+            if i < num_res_blocks:
+                if not self.num_noise:
+                    y = res(self.add_noise_channel(y, bound_multiplier=bound_multiplier))
+                else:
+                    y = res(y)
+
+        y = self.act(self.ins5(self.conv5(y, pad=True)))
+        y = self.conv6(y)
+
+        return y, features
+
+    def get_delta(self, y_pixels, eps=1e-4):
+        '''Constrains the perturbation by projecting it onto an L1 sphere:
+        its mean absolute magnitude is scaled to the distortion budget.'''
+        distortion_budget = self.distortion_budget
+        delta = torch.tanh(y_pixels)  # Project to [-1, 1]
+        avg_magnitude = delta.abs().mean([1, 2], keepdim=True)
+        max_magnitude = distortion_budget
+        delta = delta * max_magnitude / (avg_magnitude + eps)
+        return delta
+
+    def get_delta2(self, y_pixels, padding_mask, eps=1e-4):
+        '''Same as get_delta, but first zeroes out padded positions so they
+        do not contribute to the perturbation or its magnitude.'''
+        if padding_mask is not None:
+            padding_mask_ = torch.logical_not(padding_mask)
+            padding_mask_ = padding_mask_.long().unsqueeze(2)
+            y_pixels = y_pixels.transpose(1, 2)
+            y_pixels *= padding_mask_
+            y_pixels = y_pixels.transpose(1, 2)
+
+        distortion_budget = self.distortion_budget
+        delta = torch.tanh(y_pixels)  # Project to [-1, 1]
+        avg_magnitude = delta.abs().mean([1, 2], keepdim=True)
+        max_magnitude = distortion_budget
+        delta = delta * max_magnitude / (avg_magnitude + eps)
+        return delta
+
+    def forward(self, x, padding_mask):
+        x = x.transpose(1, 2)
+        if self.downsample_to:
+            # Downsample along the time axis. NOTE: upsampling back to the
+            # original length is not implemented; x_orig is kept for that purpose.
+            x_orig = x
+            x = torch.nn.functional.interpolate(
+                x, size=self.downsample_to, mode='linear')
+        y = x
+
+        if self.frequency_domain and False:  # disabled; requires a DCT library
+            # Input to the viewmaker is in the frequency domain; it outputs a
+            # frequency-domain perturbation (Discrete Cosine Transform).
+            # Shape is still [batch_size, C, T].
+            y = dct.dct_2d(y)
+
+        y_pixels, features = self.basic_net(y, self.num_res_blocks, bound_multiplier=1)
+        delta = self.get_delta2(y_pixels.clone(), padding_mask)
+
+        # NOTE: the additive perturbation (result = x + delta) is disabled;
+        # the raw network output is returned instead.
+        result = y_pixels
+
+        delta = delta.transpose(1, 2)
+        result = result.transpose(1, 2)
+
+        return result, delta
+
+
 if __name__ == "__main__":
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
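
Reviewer note (not part of the patch): below is a minimal smoke test for the new ViewMaker1 module. It is only a sketch; it assumes conformer.py (and its ScaledConv1d dependency from the recipe's scaling.py) imports cleanly from the working directory, and all shapes and values are illustrative. Since get_delta2 rescales the tanh-squashed output so that its mean absolute value equals the distortion budget, i.e. delta = tanh(y) * epsilon / (mean|tanh(y)| + eps), the printed magnitude should land close to 0.02.

import torch

# Hypothetical smoke test -- not part of the patch.
from conformer import ViewMaker1

batch, time, channels = 2, 50, 512
x = torch.randn(batch, time, channels)            # encoder features [B, T, C]
padding_mask = torch.zeros(batch, time, dtype=torch.bool)
padding_mask[:, 40:] = True                       # last 10 frames are padding

vm = ViewMaker1(num_channels=channels, distortion_budget=0.02,
                num_res_blocks=3, num_noise=1)
result, delta = vm(x, padding_mask)

assert result.shape == x.shape and delta.shape == x.shape
# Mean absolute magnitude of the perturbation matches the budget.
print(delta.abs().mean().item())  # ~0.02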