This commit is contained in:
JinZr 2023-08-14 20:14:50 +08:00
parent eb7180a0e2
commit cc629c09a2
2 changed files with 2 additions and 2 deletions

View File

@ -339,6 +339,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
use_pos_scores = True
elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
use_pos_scores = True
use_pos_scores = False
if use_pos_scores:
pos_emb = self.linear_pos(pos_emb)
@ -385,7 +386,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
storage_offset=pos_scores.stride(3) * (lm_seq_len - 1),
)
# print(pos_scores.shape)
attn_scores = attn_scores + pos_scores
# attn_scores = attn_scores + pos_scores
if torch.jit.is_scripting() or torch.jit.is_tracing():
pass

View File

@ -1 +0,0 @@
../pruned_transducer_stateless2/beam_search.py