Add BasicNorm to ConvNeXt; increase prob given to CutoffEstimator; adjust default probs of ActivationBalancer.

Daniel Povey 2022-12-18 14:14:15 +08:00
parent a424a73881
commit 5e1bf8b8ec
2 changed files with 16 additions and 12 deletions

@@ -235,8 +235,8 @@ class CutoffEstimator:
     """
     Estimates cutoffs of an arbitrary numerical quantity such that a specified
     proportion of items will be above the cutoff on average.
-    p is the proportion of items that should be above the cutoff.
+    p is the proportion of items that should be above the cutoff.
     """
     def __init__(self, p: float):
         self.p = p
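
For context, here is a minimal sketch of how such an online cutoff estimator can work; it is an illustration under assumed names (step is a hypothetical tuning knob), not the repository's exact implementation:

class CutoffEstimator:
    """Tracks a cutoff such that, on average, a proportion p of the
    values seen so far lie above it (an online quantile estimate)."""

    def __init__(self, p: float, step: float = 0.1):
        self.p = p        # target proportion of values above the cutoff
        self.step = step  # adaptation rate; hypothetical, not from the repo
        self.cutoff = 0.0

    def __call__(self, x: float) -> bool:
        above = x > self.cutoff
        # Asymmetric updates: up-moves of step*(1-p) and down-moves of
        # step*p cancel out exactly when a fraction p of observed values
        # lies above the cutoff.
        if above:
            self.cutoff += self.step * (1.0 - self.p)
        else:
            self.cutoff -= self.step * self.p
        return above

Constructed with p=0.2, such an estimator answers "is this reading in the top 20% of readings seen so far?", which is how the mem_cutoff members below use it.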
@@ -614,11 +614,10 @@ class ActivationBalancer(torch.nn.Module):
       max_abs: the maximum average-absolute-value difference from the mean
          value per channel, which we allow, before we start to modify
          the derivatives to prevent this.
-      min_prob: determines the minimum probability with which we modify the
+      prob: determines the minimum probability with which we modify the
          gradients for the {min,max}_positive and {min,max}_abs constraints,
         on each forward().  This is done randomly to prevent all layers
-         from doing it at the same time.  Early in training we may use
-         higher probabilities than this; it will decay to this value.
+         from doing it at the same time.
     """
     def __init__(
         self,
@@ -637,11 +636,11 @@ class ActivationBalancer(torch.nn.Module):
         if prob is None:
-            prob = ScheduledFloat((0.0, 0.4), (8000.0, 0.1), default=0.4)
+            prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4)
         self.prob = prob
-        # 10% of the time we will return and do nothing because memory usage
-        # is too high.
-        self.mem_cutoff = CutoffEstimator(0.1)
+        # 20% of the time we will return and do nothing because memory usage is
+        # too high.
+        self.mem_cutoff = CutoffEstimator(0.2)
         # actually self.num_channels is no longer needed except for an assertion.
         self.num_channels = num_channels
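
The new default schedule raises both the early-training probability and its floor: prob now starts at 0.5 and decays linearly to 0.125 over the first 8000 batches, where the old default decayed from 0.4 to 0.1. A sketch of the piecewise-linear interpolation that a schedule like ScheduledFloat performs, assuming it is keyed on the batch count:

def piecewise_linear(points, batch_count):
    # points: sorted (batch, value) breakpoints; the value is clamped at
    # the ends and linearly interpolated between breakpoints.
    if batch_count <= points[0][0]:
        return points[0][1]
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if batch_count <= x1:
            t = (batch_count - x0) / (x1 - x0)
            return y0 + t * (y1 - y0)
    return points[-1][1]

# New default: 0.5 at batch 0, 0.3125 at batch 4000, 0.125 from batch 8000 on.
prob = piecewise_linear([(0.0, 0.5), (8000.0, 0.125)], batch_count=4000)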
@@ -838,9 +837,9 @@ class Whiten(nn.Module):
         self.whitening_limit = whitening_limit
         self.grad_scale = grad_scale
-        # 10% of the time we will return and do nothing because memory usage
+        # 20% of the time we will return and do nothing because memory usage
         # is too high.
-        self.mem_cutoff = CutoffEstimator(0.1)
+        self.mem_cutoff = CutoffEstimator(0.2)
         if isinstance(prob, float):
             assert 0 < prob <= 1
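
Both ActivationBalancer and Whiten now skip their work on roughly the 20% of calls where memory pressure is highest, rather than 10%. A hypothetical sketch of how such a gate could be consulted in forward(); the torch.cuda.memory_allocated() probe and the helper name are assumptions based on the comments above:

import torch

def _should_skip(self, x: torch.Tensor) -> bool:
    # Act as a no-op this step if no gradient is needed, or if current
    # CUDA memory usage is in the top 20% of readings seen so far.
    if not x.requires_grad:
        return True
    return x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated())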

(second changed file)

@@ -1799,7 +1799,10 @@ class Conv2dSubsampling(nn.Module):
         )
         self.convnext1 = nn.Sequential(ConvNeXt(layer2_channels),
-                                       ConvNeXt(layer2_channels))
+                                       ConvNeXt(layer2_channels),
+                                       BasicNorm(layer2_channels,
+                                                 channel_dim=1))
         self.conv2 = nn.Sequential(
             nn.Conv2d(
@@ -1815,7 +1818,9 @@ class Conv2dSubsampling(nn.Module):
         )
         self.convnext2 = nn.Sequential(ConvNeXt(layer3_channels),
-                                       ConvNeXt(layer3_channels))
+                                       ConvNeXt(layer3_channels),
+                                       BasicNorm(layer3_channels,
+                                                 channel_dim=1))
         out_height = (((in_channels - 1) // 2) - 1) // 2
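
In the second changed file, a BasicNorm is appended after each pair of ConvNeXt blocks; channel_dim=1 makes it normalize over the Conv2d channel axis of (N, C, H, W) activations. For reference, a sketch in the spirit of icefall's BasicNorm: an RMSNorm-like layer with a learned floor on the mean-square statistic and no per-channel scale (the default eps value here is an assumption, not taken from this commit):

import torch

class BasicNorm(torch.nn.Module):
    def __init__(self, num_channels: int, channel_dim: int = -1,
                 eps: float = 0.25):
        super().__init__()
        self.num_channels = num_channels  # kept only for the shape check
        self.channel_dim = channel_dim
        # Store eps in log space so the learned floor stays positive.
        self.eps = torch.nn.Parameter(torch.tensor(eps).log())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.shape[self.channel_dim] == self.num_channels
        # Scale so that mean(x**2) over channels is ~1; exp(eps) bounds
        # the gain when the input is close to zero.
        scales = (torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
                  + self.eps.exp()) ** -0.5
        return x * scales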