from local

This commit is contained in:
dohe0342 2023-05-18 17:10:45 +09:00
parent 383cb553f4
commit 016876a900
2 changed files with 1 additions and 6 deletions

View File

@ -40,12 +40,7 @@ logger = logging.getLogger().setLevel(logging.INFO)
class TransformerEncoderAdapter(TransformerEncoder):
def __init__(self, args: Wav2Vec2Config):
super().__init__(args)
self.adapters = ResidualAdapterModule(proj_dim=512)
for p in self.adapters.parameters():
p.data /= 10.
#p.data = nn.Parameter(torch.zeros(p.size()).to('cuda'))
#p.data = nn.Parameter(torch.randn(p.size()).to('cuda')/20.)
self.lora = LoRAModule()
def forward(self, x, padding_mask=None, layer=None, tgt_layer=None):
x, layer_results = self.extract_features_with_adapter(