# Copyright (c) Facebook, Inc. and its affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import math
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from attention_module import MultiheadAttention, init_bert_params
from utils import (
    Fp32GroupNorm,
    Fp32LayerNorm,
    LayerNorm,
    SamePad,
    TransposeLast,
    get_activation_fn,
    index_put,
    pad_to_multiple,
)


class ConvFeatureExtractionModel(nn.Module):
    def __init__(
        self,
        conv_layers: List[Tuple[int, int, int]],
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
    ):
        super().__init__()

        assert mode in {"default", "layer_norm"}

        def block(
            n_in,
            n_out,
            k,
            stride,
            is_layer_norm=False,
            is_group_norm=False,
            conv_bias=False,
        ):
            def make_conv():
                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert not (
                is_layer_norm and is_group_norm
            ), "layer norm and group norm are exclusive"

            if is_layer_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        TransposeLast(),
                        Fp32LayerNorm(n_out, elementwise_affine=True),
                        TransposeLast(),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    Fp32GroupNorm(n_out, n_out, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        in_d = 1
        self.conv_layers = nn.ModuleList()
        for i, cl in enumerate(conv_layers):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl

            self.conv_layers.append(
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    is_layer_norm=mode == "layer_norm",
                    is_group_norm=mode == "default" and i == 0,
                    conv_bias=conv_bias,
                )
            )
            in_d = dim

    def forward(self, x):
        # BxT -> BxCxT
        x = x.unsqueeze(1)

        for conv in self.conv_layers:
            x = conv(x)

        return x


def make_conv_pos(e, k, g, is_batch_norm=False):
    pos_conv = nn.Conv1d(
        e,
        e,
        kernel_size=k,
        padding=k // 2,
        groups=g,
    )
    dropout = 0
    std = math.sqrt((4 * (1.0 - dropout)) / (k * e))
    nn.init.normal_(pos_conv.weight, mean=0, std=std)
    nn.init.constant_(pos_conv.bias, 0)

    if not is_batch_norm:
        pos_conv = nn.utils.parametrizations.weight_norm(
            pos_conv, name="weight", dim=2
        )
        pos_conv = nn.Sequential(pos_conv, SamePad(k), nn.GELU())
    else:
        batch_norm = nn.BatchNorm1d(e)
        pos_conv = nn.Sequential(batch_norm, pos_conv, SamePad(k), nn.GELU())

    return pos_conv
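

# Usage sketch (illustrative, not part of the original module): the feature
# extractor maps a raw waveform (B x T) to frame features (B x C x T'). The
# conv_layers spec and sample rate below are assumptions matching the common
# wav2vec 2.0 "base" setup, not values taken from this file.
def _example_feature_extractor():
    conv_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
    extractor = ConvFeatureExtractionModel(conv_layers=conv_layers, mode="default")
    waveform = torch.randn(2, 16000)  # two 1-second clips at 16 kHz
    features = extractor(waveform)  # overall stride ~320x -> (2, 512, 49)
    return features.shape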


class TransformerEncoder(nn.Module):
    def build_encoder_layer(self, args, **kwargs):
        if args.layer_type == "transformer":
            layer = TransformerSentenceEncoderLayer(
                embedding_dim=self.embedding_dim,
                ffn_embedding_dim=args.encoder_ffn_embed_dim,
                num_attention_heads=args.encoder_attention_heads,
                dropout=self.dropout,
                attention_dropout=args.attention_dropout,
                activation_dropout=args.activation_dropout,
                activation_fn=args.activation_fn,
                layer_norm_first=args.layer_norm_first,
            )
        elif args.layer_type == "trf_adp":
            use_adp = False
            if args.adp_trf_idx == "all":
                use_adp = True
            else:
                adp_trf_idx = list(
                    range(*[int(g) for g in args.adp_trf_idx.split(":")])
                )
                if kwargs.get("layer_idx", None) in adp_trf_idx:
                    use_adp = True
            if use_adp:
                layer = TransformerSentenceEncoderWithAdapterLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=args.encoder_ffn_embed_dim,
                    num_attention_heads=args.encoder_attention_heads,
                    dropout=self.dropout,
                    attention_dropout=args.attention_dropout,
                    activation_dropout=args.activation_dropout,
                    activation_fn=args.activation_fn,
                    layer_norm_first=args.layer_norm_first,
                    adapter_num=args.adp_num,
                    adapter_dim=args.adp_dim,
                    adapter_act_fn=args.adp_act_fn,
                )
            else:
                layer = TransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=args.encoder_ffn_embed_dim,
                    num_attention_heads=args.encoder_attention_heads,
                    dropout=self.dropout,
                    attention_dropout=args.attention_dropout,
                    activation_dropout=args.activation_dropout,
                    activation_fn=args.activation_fn,
                    layer_norm_first=args.layer_norm_first,
                )

        # layer = fsdp_wrap(layer)
        # if args.checkpoint_activations:
        #     layer = checkpoint_wrapper(layer)
        return layer

    def __init__(self, args):
        super().__init__()

        self.args = args
        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim
        self.required_seq_len_multiple = args.required_seq_len_multiple

        pos_conv_depth = getattr(args, "pos_conv_depth", 1)
        if pos_conv_depth > 1:
            num_layers = args.pos_conv_depth
            k = max(3, args.conv_pos // num_layers)

            def make_conv_block(e, k, g, l):
                return nn.Sequential(
                    *[
                        nn.Sequential(
                            nn.Conv1d(
                                e,
                                e,
                                kernel_size=k,
                                padding=k // 2,
                                groups=g,
                            ),
                            SamePad(k),
                            TransposeLast(),
                            LayerNorm(e, elementwise_affine=False),
                            TransposeLast(),
                            nn.GELU(),
                        )
                        for _ in range(l)
                    ]
                )

            self.pos_conv = make_conv_block(
                self.embedding_dim, k, args.conv_pos_groups, num_layers
            )
        else:
            self.pos_conv = make_conv_pos(
                self.embedding_dim,
                args.conv_pos,
                args.conv_pos_groups,
                is_batch_norm=args.conv_pos_batch_norm
                if hasattr(args, "conv_pos_batch_norm")
                else False,
            )

        self.layers = nn.ModuleList(
            [
                self.build_encoder_layer(args, layer_idx=ii)
                for ii in range(args.encoder_layers)
            ]
        )
        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop

        self.apply(init_bert_params)

    def forward(self, x, padding_mask=None, layer=None, corpus_key=None):
        x, layer_results = self.extract_features(
            x, padding_mask, layer, corpus_key=corpus_key
        )

        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features(
        self,
        x,
        padding_mask=None,
        tgt_layer=None,
        min_layer=0,
        corpus_key=None,
    ):
        if padding_mask is not None:
            x = index_put(x, padding_mask, 0)

        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        # pad to the sequence length dimension
        x, pad_length = pad_to_multiple(
            x, self.required_seq_len_multiple, dim=-2, value=0
        )
        if pad_length > 0 and padding_mask is None:
            padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
            padding_mask[:, -pad_length:] = True
        else:
            padding_mask, _ = pad_to_multiple(
                padding_mask, self.required_seq_len_multiple, dim=-1, value=True
            )
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        layer_results = []
        r = None
        for i, layer in enumerate(self.layers):
            dropout_probability = np.random.random() if self.layerdrop > 0 else 1
            if not self.training or (dropout_probability > self.layerdrop):
                layer_check = layer
                # if isinstance(layer, FullyShardedDataParallel):
                #     layer_check = layer.unwrapped_module
                if (corpus_key is None) or (
                    not isinstance(
                        layer_check,
                        (TransformerSentenceEncoderWithAdapterLayer,),
                    )
                ):
                    x, (z, lr) = layer(
                        x,
                        self_attn_padding_mask=padding_mask,
                        need_weights=False,
                    )
                else:
                    x, (z, lr) = layer(
                        x,
                        self_attn_padding_mask=padding_mask,
                        need_weights=False,
                        corpus_key=corpus_key,
                    )
                if i >= min_layer:
                    layer_results.append((x, z, lr))
            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        # undo padding
        if pad_length > 0:
            x = x[:, :-pad_length]

            def undo_pad(a, b, c):
                return (
                    a[:-pad_length],
                    b[:-pad_length] if b is not None else b,
                    c[:-pad_length],
                )

            layer_results = [undo_pad(*u) for u in layer_results]

        return x, layer_results

    def max_positions(self):
        """Maximum output length supported by the encoder."""
        return self.args.max_positions

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        return state_dict
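

# Usage sketch (illustrative, not part of the original module): a minimal args
# namespace carrying the fields TransformerEncoder actually reads. Every value
# below is an assumption for demonstration; shapes assume the imported utils
# (SamePad, pad_to_multiple, index_put) behave as in fairseq.
def _example_transformer_encoder():
    from types import SimpleNamespace

    args = SimpleNamespace(
        layer_type="transformer",
        encoder_layers=2,
        encoder_embed_dim=768,
        encoder_ffn_embed_dim=3072,
        encoder_attention_heads=12,
        dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.0,
        activation_fn="gelu",
        layer_norm_first=False,
        encoder_layerdrop=0.0,
        conv_pos=128,
        conv_pos_groups=16,
        required_seq_len_multiple=2,
        max_positions=100000,
    )
    encoder = TransformerEncoder(args)
    features = torch.randn(2, 100, 768)  # B x T x C, e.g. conv features
    padding_mask = torch.zeros(2, 100, dtype=torch.bool)  # True marks padding
    out, layer_results = encoder(features, padding_mask=padding_mask)
    return out.shape  # expected torch.Size([2, 100, 768])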


class TransformerSentenceEncoderLayer(nn.Module):
    """
    Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
    models.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        layer_norm_first: bool = False,
    ) -> None:
        super().__init__()
        # Initialize parameters
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        # Initialize blocks
        self.activation_fn = get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            self_attention=True,
        )

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.layer_norm_first = layer_norm_first

        # layer norm associated with the self attention layer
        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
        self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

        # layer norm associated with the position wise feed-forward NN
        self.final_layer_norm = LayerNorm(self.embedding_dim)

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
        need_weights: bool = False,
        att_args=None,
    ):
        """
        LayerNorm is applied either before or after the self-attention/ffn
        modules, similar to the original Transformer implementation.
        """
        residual = x

        if self.layer_norm_first:
            x = self.self_attn_layer_norm(x)
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                attn_mask=self_attn_mask,
                need_weights=False,
            )
            x = self.dropout1(x)
            x = residual + x

            residual = x
            x = self.final_layer_norm(x)
            x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)

            layer_result = x

            x = self.dropout3(x)
            x = residual + x
        else:
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=False,
            )

            x = self.dropout1(x)
            x = residual + x

            x = self.self_attn_layer_norm(x)

            residual = x
            x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)

            layer_result = x

            x = self.dropout3(x)
            x = residual + x
            x = self.final_layer_norm(x)

        return x, (attn, layer_result)
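

# Usage sketch (illustrative, not part of the original module): the layer
# operates on time-first tensors (T x B x C), matching how TransformerEncoder
# calls it; the dimensions below are assumptions for demonstration.
def _example_encoder_layer():
    layer = TransformerSentenceEncoderLayer(
        embedding_dim=768,
        ffn_embedding_dim=3072,
        num_attention_heads=12,
        layer_norm_first=True,
    )
    x = torch.randn(50, 2, 768)  # T x B x C
    y, (attn, ffn_pre_dropout) = layer(x)  # attn is None (need_weights=False)
    return y.shape  # expected torch.Size([50, 2, 768])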
""" residual = x if self.layer_norm_first: x = self.self_attn_layer_norm(x) x, attn = self.self_attn( query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, attn_mask=self_attn_mask, need_weights=False, ) x = self.dropout1(x) x = residual + x residual = x x = self.final_layer_norm(x) x = self.activation_fn(self.fc1(x)) x = self.dropout2(x) x = self.fc2(x) layer_result = x x = self.dropout3(x) x = residual + x else: x, attn = self.self_attn( query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, need_weights=False, ) x = self.dropout1(x) x = residual + x x = self.self_attn_layer_norm(x) residual = x x = self.activation_fn(self.fc1(x)) x = self.dropout2(x) x = self.fc2(x) layer_result = x x = self.dropout3(x) x = residual + x x = self.final_layer_norm(x) return x, (attn, layer_result) class AdapterFast(nn.Module): def __init__(self, adapter_num, input_dim, hidden_dim, act_fn): """ Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList orto speed up training throughput. """ super().__init__() self.adapter_num = adapter_num self.input_dim = input_dim self.hidden_dim = hidden_dim self.W_a = nn.Parameter(torch.empty(adapter_num, hidden_dim, input_dim)) self.W_b = nn.Parameter(torch.empty(adapter_num, input_dim, hidden_dim)) self.b_a = nn.Parameter(torch.empty(adapter_num, hidden_dim)) self.b_b = nn.Parameter(torch.empty(adapter_num, input_dim)) self.ln_W = nn.Parameter(torch.empty(adapter_num, input_dim)) self.ln_b = nn.Parameter(torch.empty(adapter_num, input_dim)) self.act_fn = nn.Identity() if act_fn == "relu": self.act_fn = nn.ReLU() elif act_fn == "gelu": self.act_fn = nn.GELU() elif act_fn == "selu": self.act_fn = nn.SELU() else: raise ValueError(f"unsupported {act_fn}") self.input_dim = input_dim self.reset_parameters() def reset_parameters(self): for ii in range(self.adapter_num): nn.init.kaiming_uniform_(self.W_a[ii], a=math.sqrt(5)) nn.init.kaiming_uniform_(self.W_b[ii], a=math.sqrt(5)) fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_a[ii]) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 nn.init.uniform_(self.b_a[ii], -bound, bound) fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_b[ii]) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 nn.init.uniform_(self.b_b[ii], -bound, bound) nn.init.ones_(self.ln_W) nn.init.zeros_(self.ln_b) def forward(self, x, adapter_id): ii = adapter_id h = x h = F.layer_norm(h, (self.input_dim,), self.ln_W[ii], self.ln_b[ii]) h = F.linear(h, self.W_a[ii], self.b_a[ii]) h = self.act_fn(h) h = F.linear(h, self.W_b[ii], self.b_b[ii]) outputs = h return outputs def extra_repr(self): return "adapter={}, input_dim={}, hidden_dim={}".format( self.adapter_num, self.input_dim, self.hidden_dim ) class TransformerSentenceEncoderWithAdapterLayer(TransformerSentenceEncoderLayer): """ Implements a Transformer Encoder Layer with adapters used in BERT/XLM style pre-trained models. An adapter module is added along with vanilla Transformer module. 
""" def __init__( self, embedding_dim: float = 768, ffn_embedding_dim: float = 3072, num_attention_heads: int = 8, dropout: float = 0.1, attention_dropout: float = 0.1, activation_dropout: float = 0.1, activation_fn: str = "relu", layer_norm_first: bool = False, adapter_num=201, adapter_dim=64, adapter_act_fn="relu", ) -> None: super().__init__( embedding_dim=embedding_dim, ffn_embedding_dim=ffn_embedding_dim, num_attention_heads=num_attention_heads, dropout=dropout, attention_dropout=attention_dropout, activation_dropout=activation_dropout, activation_fn=activation_fn, layer_norm_first=layer_norm_first, ) self.adapter_num = adapter_num self.adapter_dim = adapter_dim self.adapter_layer = AdapterFast( adapter_num, self.embedding_dim, self.adapter_dim, adapter_act_fn ) def forward( self, x: torch.Tensor, self_attn_mask: torch.Tensor = None, self_attn_padding_mask: torch.Tensor = None, need_weights: bool = False, att_args=None, corpus_key=None, ): x, (attn, layer_result) = super().forward( x=x, self_attn_mask=self_attn_mask, self_attn_padding_mask=self_attn_padding_mask, need_weights=need_weights, att_args=att_args, ) assert corpus_key is not None assert len(set(corpus_key)) == 1, f"corpus_key items are not same {corpus_key}" y = self.adapter_layer(x, corpus_key[0]) x = x + y return x, (attn, layer_result)