mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
updated
This commit is contained in:
parent
eb7180a0e2
commit
cc629c09a2
@ -339,6 +339,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
use_pos_scores = True
|
||||
elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
|
||||
use_pos_scores = True
|
||||
use_pos_scores = False
|
||||
|
||||
if use_pos_scores:
|
||||
pos_emb = self.linear_pos(pos_emb)
|
||||
@ -385,7 +386,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
storage_offset=pos_scores.stride(3) * (lm_seq_len - 1),
|
||||
)
|
||||
# print(pos_scores.shape)
|
||||
attn_scores = attn_scores + pos_scores
|
||||
# attn_scores = attn_scores + pos_scores
|
||||
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
pass
|
||||
|
@ -1 +0,0 @@
|
||||
../pruned_transducer_stateless2/beam_search.py
|
Loading…
x
Reference in New Issue
Block a user