Make the in_proj of feedforward modules also be a LinearWithAuxLoss.

Daniel Povey 2022-11-26 12:13:31 +08:00
parent 029f5869c4
commit d9c7e4f216

@@ -1384,7 +1384,9 @@ class FeedforwardModule(nn.Module):
                  feedforward_dim: int,
                  dropout: float):
         super(FeedforwardModule, self).__init__()
-        self.in_proj = nn.Linear(embed_dim, feedforward_dim)
+        self.in_proj = LinearWithAuxLoss(embed_dim, feedforward_dim,
+                                         aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01)))
         self.hidden_balancer = ActivationBalancer(feedforward_dim,
                                                   channel_dim=-1, max_abs=10.0,
                                                   min_prob=0.25)
@@ -1392,8 +1394,7 @@ class FeedforwardModule(nn.Module):
         self.dropout = nn.Dropout(dropout)
         self.out_proj = LinearWithAuxLoss(feedforward_dim, embed_dim,
                                           initial_scale=0.01,
-                                          aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01)),
-                                          )
+                                          aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01)))
         self.out_whiten = Whiten(num_groups=1,
                                  whitening_limit=_whitening_schedule(7.5),
                                  prob=(0.025, 0.25),
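
For readers without scaling.py at hand, here is a minimal sketch of the two helpers the diff relies on. ScheduledFloat in icefall is a float whose value follows a piecewise-linear schedule over (batch_count, value) breakpoints, so ScheduledFloat((0.0, 0.2), (1000.0, 0.01)) decays from 0.2 at the start of training to 0.01 by batch 1000. The LinearWithAuxLoss body below is largely an assumption: the diff only shows its constructor arguments, so the auxiliary objective used here (a penalty on large outputs) and the _AttachAuxGrad helper are hypothetical stand-ins for whatever auxiliary loss the real class computes, and initial_scale is assumed to shrink the initial weights as in icefall's ScaledLinear.

import torch
import torch.nn as nn
from torch import Tensor


class ScheduledFloat:
    # Sketch: piecewise-linear schedule over (batch_count, value) breakpoints.
    # float(s) evaluates the schedule at the current batch count; here the
    # batch count is assumed to be updated by the training loop, which is a
    # simplification of how the real class tracks it.
    def __init__(self, *points):
        self.points = sorted(points)
        self.batch_count = 0.0

    def __float__(self) -> float:
        x = self.batch_count
        if x <= self.points[0][0]:
            return float(self.points[0][1])
        for (x0, y0), (x1, y1) in zip(self.points, self.points[1:]):
            if x <= x1:
                # linear interpolation between the two breakpoints
                return float(y0 + (y1 - y0) * (x - x0) / (x1 - x0))
        return float(self.points[-1][1])


class _AttachAuxGrad(torch.autograd.Function):
    # Hypothetical helper: identity in the forward pass; in backward, adds a
    # precomputed auxiliary gradient to whatever gradient flows in from the
    # main loss.
    @staticmethod
    def forward(ctx, x: Tensor, aux_grad: Tensor) -> Tensor:
        ctx.save_for_backward(aux_grad)
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_out: Tensor):
        (aux_grad,) = ctx.saved_tensors
        return grad_out + aux_grad, None


class LinearWithAuxLoss(nn.Linear):
    # Sketch: nn.Linear plus an auxiliary gradient whose strength follows
    # aux_grad_scale.  The auxiliary objective used here, 0.5 * mean(out^2),
    # is a placeholder; the real class's auxiliary loss is not shown in the
    # diff above.
    def __init__(self, in_features: int, out_features: int,
                 initial_scale: float = 1.0,
                 aux_grad_scale: ScheduledFloat = None):
        super().__init__(in_features, out_features)
        self.aux_grad_scale = aux_grad_scale
        with torch.no_grad():
            # assumed to mirror ScaledLinear: scale down the initial params
            self.weight.mul_(initial_scale)
            if self.bias is not None:
                self.bias.mul_(initial_scale)

    def forward(self, x: Tensor) -> Tensor:
        out = super().forward(x)
        scale = 0.0 if self.aux_grad_scale is None else float(self.aux_grad_scale)
        if self.training and scale != 0.0:
            # d(0.5 * mean(out^2)) / d(out) == out / numel(out); scale it and
            # attach it to the graph without changing out's forward value.
            aux_grad = (scale / out.numel()) * out.detach()
            out = _AttachAuxGrad.apply(out, aux_grad)
        return out

Under these assumptions, the swapped-in in_proj behaves exactly like the old nn.Linear at inference time; during training it contributes a small auxiliary gradient whose strength decays along the ScheduledFloat schedule, which is why the commit can replace nn.Linear without touching the rest of FeedforwardModule.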