diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 7a7a09c27..d0be5af00 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -614,7 +614,9 @@ class RelPositionMultiheadAttention(nn.Module): assert ( head_dim * num_heads == embed_dim ), "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 + q = q * scaling if torch.equal(query, key) and torch.equal(key, value): # self-attention @@ -764,7 +766,7 @@ class RelPositionMultiheadAttention(nn.Module): attn_output_weights = ( matrix_ac + matrix_bd - ) * scaling # (batch, head, time1, time2) + ) # (batch, head, time1, time2) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, -1 diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 80febc677..c9654cc94 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.5", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.5_pbs", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved