fixed feat dim

This commit is contained in:
zr_jin 2023-07-23 18:20:55 +08:00
parent bb8d016722
commit 3bd2e8e6cc

View File

@ -336,7 +336,7 @@ class AlignmentAttentionModule(nn.Module):
attn_weights = self.cross_attn_weights(lm_pruned, am_pruned, pos_emb)
label_level_am_representation = self.cross_attn(am_pruned, attn_weights)
return label_level_am_representation
return label_level_am_representation.reshape(batch_size, T, prune_range, encoder_dim)
if __name__ == "__main__":