mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Increase initial-lr from 0.04 to 0.05, plus changes for diagnostics
This commit is contained in:
parent
2675944f01
commit
b988bc0e33
@ -863,6 +863,10 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
self.linear_pos = ScaledLinear(embed_dim, attention_dim // 2, bias=False,
|
||||
initial_scale=0.05)
|
||||
|
||||
# the following are for diagnosics only, see --print-diagnostics option
|
||||
self.copy_pos_query = nn.Identity()
|
||||
self.copy_query = nn.Identity()
|
||||
|
||||
self.in_balancer = ActivationBalancer(3 * attention_dim,
|
||||
channel_dim=-1, max_abs=5.0)
|
||||
self.out_proj = ScaledLinear(
|
||||
@ -1008,9 +1012,11 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
q, k, pv = x.chunk(3, dim=-1)
|
||||
p, v = pv.chunk(2, dim=-1)
|
||||
|
||||
|
||||
k = self.whiten_keys(k) # does nothing in the forward pass.
|
||||
v = self.whiten_values(v) # does nothing in the forward pass.
|
||||
q = self.copy_query(q) # for diagnostics only, does nothing.
|
||||
p = self.copy_pos_query(p) # for diagnostics only, does nothing.
|
||||
|
||||
|
||||
if attn_mask is not None:
|
||||
assert (
|
||||
|
||||
@ -230,7 +230,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--initial-lr",
|
||||
type=float,
|
||||
default=0.04,
|
||||
default=0.05,
|
||||
help="The initial learning rate. This value should not need "
|
||||
"to be changed.",
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user