from local

This commit is contained in:
dohe0342 2023-05-18 16:56:48 +09:00
parent 2a54d49f96
commit 09701ed03a
2 changed files with 42 additions and 0 deletions

View File

@ -229,6 +229,48 @@ class TransformerEncoderAdapter(TransformerEncoder):
return x, layer_results return x, layer_results
class LoRAModule(nn.Module):
"""
Implements a residual adapter based on https://arxiv.org/pdf/1909.08478.pdf
modules similar to the original residual adapter except layernorm location (first -> last)
"""
def __init__(
self,
embedding_dim: float = 768,
layer_num: int = 12,
proj_dim: float = 512,
) -> None:
super().__init__()
self.type = 'linear'
def build_adapter(embedding_dim, proj_dim, type_=self.type):
if type_ == 'conv':
return ConvolutionModule(768, 31)
else:
return nn.Sequential(
#nn.LayerNorm(embedding_dim),
nn.Linear(embedding_dim, proj_dim),
nn.ReLU(),
nn.Linear(proj_dim, embedding_dim),
nn.LayerNorm(embedding_dim),
)
self.adapter_layers = nn.ModuleList(
[build_adapter(embedding_dim, proj_dim, type_=self.type) for _ in range(layer_num)]
)
def forward(self, x, layer_id=-1):
x = x.transpose(0, 1)
residual = x
x = self.adapter_layers[layer_id](x)
x = residual + x
x = x.transpose(0, 1)
return x
class ResidualAdapterModule(nn.Module): class ResidualAdapterModule(nn.Module):
""" """
Implements a residual adapter based on https://arxiv.org/pdf/1909.08478.pdf Implements a residual adapter based on https://arxiv.org/pdf/1909.08478.pdf