diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.data2vec_audio.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.data2vec_audio.py.swp
index fdca8c5ac..9d4f01465 100644
Binary files a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.data2vec_audio.py.swp and b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.data2vec_audio.py.swp differ
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/data2vec_audio.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/data2vec_audio.py
index 8cae787ab..20adb7724 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/data2vec_audio.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/data2vec_audio.py
@@ -465,13 +465,6 @@ class Data2VecAudioModel(BaseFairseqModel):
 
         features = features.transpose(1, 2)
 
-        ## for prompt tuning
-        if prompt is not None:
-            #features = torch.cat([features, prompt])
-            features = torch.cat([prompt, features])
-
-        features = self.layer_norm(features)
-
         orig_padding_mask = padding_mask
 
         if padding_mask is not None and padding_mask.any():
@@ -495,6 +488,13 @@ class Data2VecAudioModel(BaseFairseqModel):
         else:
             padding_mask = None
 
+        ## for prompt tuning
+        if prompt is not None:
+            #features = torch.cat([features, prompt])
+            features = torch.cat([prompt, features])
+
+        features = self.layer_norm(features)
+
         #print(padding_mask.size())
         #print((padding_mask[0] == True).nonzero(as_tuple=True)[0])
         #print((padding_mask[1] == True).nonzero(as_tuple=True)[0][1])
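
Note for reviewers: the change relocates the prompt-tuning block (prompt concatenation followed by layer normalization) from before the padding-mask computation to after it. The following is a minimal, self-contained sketch of what that relocated block does, not the model's actual code path; the tensor shapes (features as (batch, time, channels) after transpose(1, 2), prompt as learned prompt embeddings of length prompt_len) are assumptions for illustration and are not confirmed by the diff.

import torch
import torch.nn as nn

# Hypothetical sizes for illustration only.
batch, time, channels, prompt_len = 2, 50, 768, 8

features = torch.randn(batch, time, channels)      # stand-in for the output of transpose(1, 2)
prompt = torch.randn(batch, prompt_len, channels)  # hypothetical learned prompt embeddings
layer_norm = nn.LayerNorm(channels)

# Prepend the prompt along the time axis, then normalize, mirroring the relocated block.
# Note: torch.cat([prompt, features]) in the diff uses the default dim=0 (the batch axis),
# so the prompt tensor there must be shaped accordingly; this sketch concatenates on dim=1.
features = torch.cat([prompt, features], dim=1)
features = layer_norm(features)

print(features.shape)  # torch.Size([2, 58, 768])

One consequence worth checking in review: because the padding mask is computed before the prompt is prepended, it is not extended to cover the added prompt positions, so downstream consumers of padding_mask see a sequence length that no longer matches features.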