mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
updated
This commit is contained in:
parent
eb7180a0e2
commit
cc629c09a2
@ -339,6 +339,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
|||||||
use_pos_scores = True
|
use_pos_scores = True
|
||||||
elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
|
elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
|
||||||
use_pos_scores = True
|
use_pos_scores = True
|
||||||
|
use_pos_scores = False
|
||||||
|
|
||||||
if use_pos_scores:
|
if use_pos_scores:
|
||||||
pos_emb = self.linear_pos(pos_emb)
|
pos_emb = self.linear_pos(pos_emb)
|
||||||
@ -385,7 +386,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
|||||||
storage_offset=pos_scores.stride(3) * (lm_seq_len - 1),
|
storage_offset=pos_scores.stride(3) * (lm_seq_len - 1),
|
||||||
)
|
)
|
||||||
# print(pos_scores.shape)
|
# print(pos_scores.shape)
|
||||||
attn_scores = attn_scores + pos_scores
|
# attn_scores = attn_scores + pos_scores
|
||||||
|
|
||||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -1 +0,0 @@
|
|||||||
../pruned_transducer_stateless2/beam_search.py
|
|
||||||
Loading…
x
Reference in New Issue
Block a user