mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Merge branch 'scaled_adam_exp760' into scaled_adam_exp765
This commit is contained in:
commit
cff350d8de
@ -1712,10 +1712,13 @@ class ConvNeXt(nn.Module):
|
|||||||
|
|
||||||
self.out_balancer = ActivationBalancer(
|
self.out_balancer = ActivationBalancer(
|
||||||
channels, channel_dim=1,
|
channels, channel_dim=1,
|
||||||
min_positive=0.5, max_positive=0.5,
|
min_positive=0.4, max_positive=0.6,
|
||||||
min_abs=0.25, max_abs=6.0,
|
min_abs=0.25, max_abs=6.0,
|
||||||
)
|
)
|
||||||
|
self.out_whiten = Whiten(num_groups=1,
|
||||||
|
whitening_limit=5.0,
|
||||||
|
prob=(0.025, 0.25),
|
||||||
|
grad_scale=0.01)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
@ -1739,6 +1742,10 @@ class ConvNeXt(nn.Module):
|
|||||||
|
|
||||||
x = bypass + x
|
x = bypass + x
|
||||||
x = self.out_balancer(x)
|
x = self.out_balancer(x)
|
||||||
|
x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last
|
||||||
|
x = self.out_whiten(x)
|
||||||
|
x = x.transpose(1, 3) # (N, C, H, W)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ -1845,13 +1852,13 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
self.scale_min = ScheduledFloat((0.0, 0.9), (4000.0, 0.1))
|
self.scale_min = ScheduledFloat((0.0, 0.9), (4000.0, 0.1))
|
||||||
|
|
||||||
self.out = nn.Linear(out_width * layer3_channels, out_channels)
|
self.out = nn.Linear(out_width * layer3_channels, out_channels)
|
||||||
# use a much larger than normal grad_scale on this whitening module;
|
# use a larger than normal grad_scale on this whitening module; there is
|
||||||
# there is only one such module, so there is not a concern about adding
|
# only one such module, so there is not a concern about adding together
|
||||||
# together many copies of this extra gradient term.
|
# many copies of this extra gradient term.
|
||||||
self.out_whiten = Whiten(num_groups=1,
|
self.out_whiten = Whiten(num_groups=1,
|
||||||
whitening_limit=_whitening_schedule(4.0),
|
whitening_limit=_whitening_schedule(4.0),
|
||||||
prob=(0.025, 0.25),
|
prob=(0.025, 0.25),
|
||||||
grad_scale=0.05)
|
grad_scale=0.02)
|
||||||
|
|
||||||
self.out_norm = BasicNorm(out_channels)
|
self.out_norm = BasicNorm(out_channels)
|
||||||
self.dropout = Dropout2(dropout)
|
self.dropout = Dropout2(dropout)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user