from local

This commit is contained in:
dohe0342 2023-04-11 15:02:37 +09:00
parent f56ad643e6
commit 0979652201
2 changed files with 8 additions and 7 deletions

View File

@ -467,7 +467,8 @@ class Data2VecAudioModel(BaseFairseqModel):
## for prompt tuning
if prompt is not None:
features = torch.cat([features, prompt])
#features = torch.cat([features, prompt])
features = torch.cat([prompt, features])
features = self.layer_norm(features)
@ -494,12 +495,12 @@ class Data2VecAudioModel(BaseFairseqModel):
else:
padding_mask = None
print(padding_mask.size())
print((padding_mask[0] == True).nonzero(as_tuple=True)[0])
print((padding_mask[1] == True).nonzero(as_tuple=True)[0][1])
print((padding_mask[2] == True).nonzero(as_tuple=True)[0][2])
print((padding_mask[3] == True).nonzero(as_tuple=True)[0][3])
exit()
#print(padding_mask.size())
#print((padding_mask[0] == True).nonzero(as_tuple=True)[0])
#print((padding_mask[1] == True).nonzero(as_tuple=True)[0][1])
#print((padding_mask[2] == True).nonzero(as_tuple=True)[0][2])
#print((padding_mask[3] == True).nonzero(as_tuple=True)[0][3])
#exit()
if self.post_extract_proj is not None:
features = self.post_extract_proj(features)