# Copyright (c) Facebook, Inc. and its affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import logging
import math
from typing import Callable, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

logger = logging.getLogger(__name__)


def relu_squared(x: torch.Tensor):
    """Return ``relu(x) ** 2`` elementwise."""
    return F.relu(x).pow(2)


def gelu_accurate(x):
    """Tanh approximation of GELU.

    The constant ``sqrt(2 / pi)`` is computed on first use and cached as an
    attribute on this function so later calls skip the ``math`` work.
    """
    if not hasattr(gelu_accurate, "_a"):
        gelu_accurate._a = math.sqrt(2 / math.pi)
    return (
        0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
    )


def is_xla_tensor(tensor):
    """Return True if ``tensor`` is a torch tensor on an XLA device."""
    return torch.is_tensor(tensor) and tensor.device.type == "xla"


def index_put(tensor, indices, value):
    """Assign ``value`` at ``indices`` and return the resulting tensor.

    Off XLA this is plain in-place ``tensor[indices] = value``. On XLA the
    assignment is emulated arithmetically: ``indices`` is unsqueezed up to the
    rank of ``tensor`` (and expanded if its last dim is smaller), then blended
    as ``tensor * ~indices + value * indices``. Because ``~indices`` is used,
    ``indices`` is presumably a boolean mask — confirm at call sites.
    """
    if is_xla_tensor(tensor):
        for _ in range(indices.dim(), tensor.dim()):
            indices = indices.unsqueeze(-1)
        if indices.size(-1) < tensor.size(-1):
            indices = indices.expand_as(tensor)
        tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices)
    else:
        tensor[indices] = value
    return tensor


def pad_to_multiple(x, multiple, dim=-1, value=0):
    """Right-pad ``x`` along ``dim`` so that its size is a multiple of ``multiple``.

    Args:
        x: tensor to pad, or None.
        multiple: target size multiple.
        dim: dimension to pad; negative and non-negative indices are accepted.
        value: fill value for the padding.

    Returns:
        ``(padded, remainder)``, where ``remainder`` is the number of elements
        appended. Returns ``(None, 0)`` for ``x is None`` and ``(x, 0)`` when no
        padding is needed.
    """
    # Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
    if x is None:
        return None, 0
    # F.pad's spec is ordered from the last dim backwards, so work with a
    # negative index; a non-negative ``dim`` would otherwise pad the last dim.
    if dim >= 0:
        dim -= x.dim()
    tsz = x.size(dim)
    remainder = (-tsz) % multiple
    if remainder == 0:
        return x, 0
    # (0, 0) for every dim after ``dim``, then (left=0, right=remainder).
    pad_offset = (0,) * (-1 - dim) * 2
    return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder


def gelu(x: torch.Tensor) -> torch.Tensor:
    """GELU computed in float32, cast back to the input's dtype."""
    return torch.nn.functional.gelu(x.float()).type_as(x)


def get_activation_fn(activation: str) -> Callable:
    """Returns the activation function corresponding to `activation`.

    Raises:
        RuntimeError: if ``activation`` is not a supported name.
    """
    if activation == "relu":
        return F.relu
    elif activation == "relu_squared":
        return relu_squared
    elif activation == "gelu":
        return gelu
    elif activation == "gelu_fast":
        # legacy alias: resolves to the same function as "gelu_accurate"
        return gelu_accurate
    elif activation == "gelu_accurate":
        return gelu_accurate
    elif activation == "tanh":
        return torch.tanh
    elif activation == "linear":
        return lambda x: x
    elif activation == "swish":
        # Must be a callable on tensors like the other entries; returning the
        # ``nn.SiLU`` class would construct a module instead of applying it.
        return F.silu
    else:
        raise RuntimeError("--activation-fn {} not supported".format(activation))


class SamePad(nn.Module):
    """Trim trailing positions along dim 2 after a padded convolution.

    Removes ``kernel_size - 1`` positions when ``causal`` is set, otherwise one
    position for even kernel sizes and none for odd ones. Assumes a
    ``(batch, channels, time)``-like layout — confirm at call sites.
    """

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            self.remove = kernel_size - 1
        else:
            self.remove = 1 if kernel_size % 2 == 0 else 0

    def forward(self, x):
        if self.remove > 0:
            x = x[:, :, : -self.remove]
        return x


class SamePad2d(nn.Module):
    """Trim the last row and column (dims 2 and 3) of a 4-D input for even kernels."""

    def __init__(self, kernel_size):
        super().__init__()
        self.remove = 1 if kernel_size % 2 == 0 else 0

    def forward(self, x):
        assert len(x.size()) == 4
        if self.remove > 0:
            x = x[:, :, : -self.remove, : -self.remove]
        return x


class TransposeLast(nn.Module):
    """Swap dimension ``tranpose_dim`` with the last dimension.

    If ``deconstruct_idx`` is given, ``x[deconstruct_idx]`` is selected first
    (e.g. to pick one element out of a tuple output). The misspelled
    ``tranpose_dim`` name is kept for backward compatibility.
    """

    def __init__(self, deconstruct_idx=None, tranpose_dim=-2):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx
        self.tranpose_dim = tranpose_dim

    def forward(self, x):
        if self.deconstruct_idx is not None:
            x = x[self.deconstruct_idx]
        return x.transpose(self.tranpose_dim, -1)


try:
    from apex.normalization import FusedLayerNorm as _FusedLayerNorm

    has_fused_layernorm = True

    class FusedLayerNorm(_FusedLayerNorm):
        """apex fused LayerNorm that runs CUDA inputs under their own device context."""

        @torch.jit.unused
        def forward(self, x):
            if not x.is_cuda:
                return super().forward(x)
            else:
                with torch.cuda.device(x.device):
                    return super().forward(x)

except ImportError:
    has_fused_layernorm = False


def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    """Build a LayerNorm, preferring apex's fused kernel when it can be used.

    The fused version is used only when not exporting (scripting/tracing forces
    export), CUDA is available, and apex imported successfully; otherwise a
    plain ``torch.nn.LayerNorm`` is returned.
    """
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        export = True
    if not export and torch.cuda.is_available() and has_fused_layernorm:
        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)


class Fp32LayerNorm(nn.LayerNorm):
    """LayerNorm computed in float32 (input and affine params upcast), output cast back."""

    def forward(self, input):
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)


class Fp32GroupNorm(nn.GroupNorm):
    """GroupNorm computed in float32 (input and affine params upcast), output cast back."""

    def forward(self, input):
        output = F.group_norm(
            input.float(),
            self.num_groups,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)


def softmax(x, dim: int, onnx_trace: bool = False):
    """Softmax over ``dim`` producing a float32 result.

    When ``onnx_trace`` is set the input is cast with ``.float()`` rather than
    passing ``dtype=`` to ``F.softmax``.
    """
    if onnx_trace:
        return F.softmax(x.float(), dim=dim)
    else:
        return F.softmax(x, dim=dim, dtype=torch.float32)


def quant_noise(module, p, block_size):
    """
    Wraps modules and applies quantization noise to the weights for
    subsequent quantization with Iterative Product Quantization as
    described in "Training with Quantization Noise for Extreme Model Compression"

    Args:
        - module: nn.Module
        - p: amount of Quantization Noise
        - block_size: size of the blocks for subsequent quantization with iPQ

    Remarks:
        - Module weights must have the right sizes wrt the block size
        - Only Linear, Embedding and Conv2d modules are supported for the moment
        - For more detail on how to quantize by blocks with convolutional weights,
          see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
        - We implement the simplest form of noise here as stated in the paper
          which consists in randomly dropping blocks

    The noise is applied by a forward pre-hook, only while the module is in
    training mode: selected blocks are zeroed and the remaining weights are
    rescaled by ``1 / (1 - p)``. Returns ``module`` (unchanged if ``p <= 0``).
    """

    # if no quantization noise, don't register hook
    if p <= 0:
        return module

    # supported modules
    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))

    # test whether module.weight has the right sizes wrt block_size
    is_conv = module.weight.ndim == 4

    # 2D matrix
    if not is_conv:
        assert (
            module.weight.size(1) % block_size == 0
        ), "Input features must be a multiple of block sizes"

    # 4D matrix
    else:
        # 1x1 convolutions
        if module.kernel_size == (1, 1):
            assert (
                module.in_channels % block_size == 0
            ), "Input channels must be a multiple of block sizes"
        # regular convolutions
        else:
            k = module.kernel_size[0] * module.kernel_size[1]
            assert k % block_size == 0, "Kernel size must be a multiple of block size"

    def _forward_pre_hook(mod, input):
        # no noise for evaluation
        if not mod.training:
            return

        weight = mod.weight
        if not is_conv:
            # split weight matrix into blocks and randomly drop selected blocks
            in_features = weight.size(1)
            out_features = weight.size(0)
            mask = torch.zeros(
                in_features // block_size * out_features, device=weight.device
            )
            mask.bernoulli_(p)
            mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
        else:
            in_channels = mod.in_channels
            out_channels = mod.out_channels
            if mod.kernel_size == (1, 1):
                mask = torch.zeros(
                    int(in_channels // block_size * out_channels),
                    device=weight.device,
                )
                mask.bernoulli_(p)
                mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
                # Add the trailing 1x1 kernel dims so the mask has the weight's
                # exact shape; out-of-place masked_fill broadcasts self and mask
                # together, so a 2-D mask would change the weight's shape.
                mask = mask.unsqueeze(2).unsqueeze(3)
            else:
                # one draw per (out, in) channel pair, tiled over the kernel
                mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
                mask.bernoulli_(p)
                mask = (
                    mask.unsqueeze(2)
                    .unsqueeze(3)
                    .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
                )

        # scale weights and apply mask
        mask = mask.to(torch.bool)  # x.bool() is not currently supported in TorchScript
        s = 1 / (1 - p)
        mod.weight.data = s * weight.masked_fill(mask, 0)

    module.register_forward_pre_hook(_forward_pre_hook)
    return module


class FairseqDropout(nn.Module):
    """Dropout that can optionally stay active at inference time.

    Dropout is applied when ``p > 0`` and the module is training, or when
    ``apply_during_inference`` has been enabled via ``make_generation_fast_``.
    """

    def __init__(self, p, module_name=None):
        super().__init__()
        self.p = p
        self.module_name = module_name
        self.apply_during_inference = False

    def forward(self, x, inplace: bool = False):
        if self.p > 0 and (self.training or self.apply_during_inference):
            return F.dropout(x, p=self.p, training=True, inplace=inplace)
        else:
            return x

    def make_generation_fast_(
        self,
        name: str,
        retain_dropout: bool = False,
        retain_dropout_modules: Optional[List[str]] = None,
        **kwargs
    ):
        """Enable dropout at inference if requested for this module.

        With ``retain_dropout`` set, dropout is kept for every module when
        ``retain_dropout_modules`` is None, or only for modules whose
        ``module_name`` is listed. A module without a ``module_name`` cannot be
        matched against the list; that case is logged instead of ignored.
        """
        if retain_dropout:
            if retain_dropout_modules is not None and self.module_name is None:
                logger.warning(
                    "Cannot enable dropout during inference for module %s "
                    "because module_name was not set",
                    name,
                )
            elif (
                retain_dropout_modules is None  # if None, apply to all modules
                or self.module_name in retain_dropout_modules
            ):
                self.apply_during_inference = True


class GradMultiply(torch.autograd.Function):
    """Pass ``x`` through unchanged in forward; multiply its gradient by ``scale``."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        res = x.new(x)
        return res

    @staticmethod
    def backward(ctx, grad):
        # no gradient flows to ``scale``
        return grad * ctx.scale, None