Scale up pos_bias_u and pos_bias_v before use.

This commit is contained in:
Daniel Povey 2022-03-11 14:37:52 +08:00
parent e3e14cf7a4
commit ab9a17413a
2 changed files with 4 additions and 2 deletions

View File

@ -614,7 +614,9 @@ class RelPositionMultiheadAttention(nn.Module):
assert ( assert (
head_dim * num_heads == embed_dim head_dim * num_heads == embed_dim
), "embed_dim must be divisible by num_heads" ), "embed_dim must be divisible by num_heads"
scaling = float(head_dim) ** -0.5 scaling = float(head_dim) ** -0.5
q = q * scaling
if torch.equal(query, key) and torch.equal(key, value): if torch.equal(query, key) and torch.equal(key, value):
# self-attention # self-attention
@ -764,7 +766,7 @@ class RelPositionMultiheadAttention(nn.Module):
attn_output_weights = ( attn_output_weights = (
matrix_ac + matrix_bd matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2) ) # (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view( attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, -1 bsz * num_heads, tgt_len, -1

View File

@ -110,7 +110,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.5", default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.5_pbs",
help="""The experiment dir. help="""The experiment dir.
It specifies the directory where all training related It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved files, e.g., checkpoints, log, etc, are saved