diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index f876b15ce..b2de7232f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -863,6 +863,10 @@ class RelPositionMultiheadAttention(nn.Module): self.linear_pos = ScaledLinear(embed_dim, attention_dim // 2, bias=False, initial_scale=0.05) + # the following are for diagnosics only, see --print-diagnostics option + self.copy_pos_query = nn.Identity() + self.copy_query = nn.Identity() + self.in_balancer = ActivationBalancer(3 * attention_dim, channel_dim=-1, max_abs=5.0) self.out_proj = ScaledLinear( @@ -1008,9 +1012,11 @@ class RelPositionMultiheadAttention(nn.Module): q, k, pv = x.chunk(3, dim=-1) p, v = pv.chunk(2, dim=-1) - k = self.whiten_keys(k) # does nothing in the forward pass. v = self.whiten_values(v) # does nothing in the forward pass. + q = self.copy_query(q) # for diagnostics only, does nothing. + p = self.copy_pos_query(p) # for diagnostics only, does nothing. + if attn_mask is not None: assert ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 18d854279..e5d4f73ad 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -230,7 +230,7 @@ def get_parser(): parser.add_argument( "--initial-lr", type=float, - default=0.04, + default=0.05, help="The initial learning rate. This value should not need " "to be changed.", )