mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Scale up pos_bias_u and pos_bias_v before use.
This commit is contained in:
parent
e3e14cf7a4
commit
ab9a17413a
@ -614,7 +614,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
assert (
|
||||
head_dim * num_heads == embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
|
||||
scaling = float(head_dim) ** -0.5
|
||||
q = q * scaling
|
||||
|
||||
if torch.equal(query, key) and torch.equal(key, value):
|
||||
# self-attention
|
||||
@ -764,7 +766,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
|
||||
attn_output_weights = (
|
||||
matrix_ac + matrix_bd
|
||||
) * scaling # (batch, head, time1, time2)
|
||||
) # (batch, head, time1, time2)
|
||||
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz * num_heads, tgt_len, -1
|
||||
|
@ -110,7 +110,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.5",
|
||||
default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.5_pbs",
|
||||
help="""The experiment dir.
|
||||
It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
|
Loading…
x
Reference in New Issue
Block a user