from local

This commit is contained in:
dohe0342 2023-04-11 15:03:06 +09:00
parent 0979652201
commit 3b1ca50e6f
2 changed files with 7 additions and 7 deletions

View File

@ -465,13 +465,6 @@ class Data2VecAudioModel(BaseFairseqModel):
features = features.transpose(1, 2)
## for prompt tuning
if prompt is not None:
#features = torch.cat([features, prompt])
features = torch.cat([prompt, features])
features = self.layer_norm(features)
orig_padding_mask = padding_mask
if padding_mask is not None and padding_mask.any():
@ -495,6 +488,13 @@ class Data2VecAudioModel(BaseFairseqModel):
else:
padding_mask = None
## for prompt tuning
if prompt is not None:
#features = torch.cat([features, prompt])
features = torch.cat([prompt, features])
features = self.layer_norm(features)
#print(padding_mask.size())
#print((padding_mask[0] == True).nonzero(as_tuple=True)[0])
#print((padding_mask[1] == True).nonzero(as_tuple=True)[0][1])