Add BasicNorm to ConvNeXt; increase prob given to CutoffEstimator; adjust default probs of ActivationBalancer.

Daniel Povey 2022-12-18 14:14:15 +08:00
parent a424a73881
commit 5e1bf8b8ec
2 changed files with 16 additions and 12 deletions


@@ -235,8 +235,8 @@ class CutoffEstimator:
     """
     Estimates cutoffs of an arbitrary numerical quantity such that a specified
     proportion of items will be above the cutoff on average.
-    p is the proportion of items that should be above the cutoff.
+    p is the proportion of items that should be above the cutoff.
     """
     def __init__(self, p: float):
         self.p = p
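
The contract described by this docstring is that an estimator constructed with p should, over time, return True for roughly a proportion p of the values it is called on. Below is a minimal illustrative quantile tracker with that behaviour; it is not the actual CutoffEstimator implementation, just the idea (the class name and the step parameter are made up for the sketch):

class SimpleCutoffEstimator:
    """Keeps a running cutoff so that, on average, a proportion `p` of the
    observed values falls above it (illustrative sketch only)."""

    def __init__(self, p: float, step: float = 0.1):
        self.p = p
        self.cutoff = 0.0
        self.step = step  # how quickly the cutoff adapts to new values

    def __call__(self, x: float) -> bool:
        above = x > self.cutoff
        # Stochastic quantile update: raising the cutoff with weight (1 - p)
        # when it is exceeded and lowering it with weight p otherwise puts the
        # stationary point at the (1 - p)-quantile, so P(x > cutoff) ~= p.
        if above:
            self.cutoff += self.step * (1.0 - self.p)
        else:
            self.cutoff -= self.step * self.p
        return above

With p = 0.2, the call returns True for roughly the top 20% of the values it sees, which is how the mem_cutoff gates below are used.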
@@ -614,11 +614,10 @@ class ActivationBalancer(torch.nn.Module):
       max_abs: the maximum average-absolute-value difference from the mean
          value per channel, which we allow, before we start to modify
          the derivatives to prevent this.
-      min_prob: determines the minimum probability with which we modify the
+      prob: determines the minimum probability with which we modify the
          gradients for the {min,max}_positive and {min,max}_abs constraints,
          on each forward().  This is done randomly to prevent all layers
-         from doing it at the same time.  Early in training we may use
-         higher probabilities than this; it will decay to this value.
+         from doing it at the same time.
    """
    def __init__(
        self,
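
The reworded prob documentation describes a per-call random gate: on each forward() the balancing logic runs only with the current probability, so different layers modify their gradients on different steps. A hedged sketch of that gating (the helper name _modify_grads is hypothetical, and this is not the exact ActivationBalancer.forward):

import random

import torch


def balancer_forward_sketch(self, x: torch.Tensor) -> torch.Tensor:
    # self.prob may be a plain float or a schedule object that can be
    # converted with float() to its value for the current training step.
    if (not x.requires_grad) or random.random() > float(self.prob):
        # With probability 1 - prob the module is a no-op on this call, so
        # not every layer pays the balancing cost on the same step.
        return x
    return self._modify_grads(x)  # hypothetical helper holding the real constraint logic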
@@ -637,11 +636,11 @@ class ActivationBalancer(torch.nn.Module):
        if prob is None:
-           prob = ScheduledFloat((0.0, 0.4), (8000.0, 0.1), default=0.4)
+           prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4)
        self.prob = prob
-       # 10% of the time we will return and do nothing because memory usage
-       # is too high.
-       self.mem_cutoff = CutoffEstimator(0.1)
+       # 20% of the time we will return and do nothing because memory usage is
+       # too high.
+       self.mem_cutoff = CutoffEstimator(0.2)
        # actually self.num_channels is no longer needed except for an assertion.
        self.num_channels = num_channels
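
Each (x, y) pair given to ScheduledFloat is, as the class is used in this codebase, a (batch_count, value) breakpoint that is interpolated piecewise-linearly, so the new default decays the balancing probability from 0.5 at the start of training to 0.125 by batch 8000 (the old default went from 0.4 to 0.1). A small illustrative lookup, not the ScheduledFloat class itself:

from typing import List, Tuple


def scheduled_value(batch: float, points: List[Tuple[float, float]]) -> float:
    # Piecewise-linear interpolation between (batch_count, value) breakpoints,
    # clamped to the first/last value outside their range.
    if batch <= points[0][0]:
        return points[0][1]
    if batch >= points[-1][0]:
        return points[-1][1]
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if x0 <= batch <= x1:
            t = (batch - x0) / (x1 - x0)
            return y0 + t * (y1 - y0)
    return points[-1][1]


# New default schedule for ActivationBalancer's prob:
print(scheduled_value(0, [(0.0, 0.5), (8000.0, 0.125)]))      # 0.5
print(scheduled_value(4000, [(0.0, 0.5), (8000.0, 0.125)]))   # 0.3125
print(scheduled_value(20000, [(0.0, 0.5), (8000.0, 0.125)]))  # 0.125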
@@ -838,9 +837,9 @@ class Whiten(nn.Module):
        self.whitening_limit = whitening_limit
        self.grad_scale = grad_scale
-       # 10% of the time we will return and do nothing because memory usage
+       # 20% of the time we will return and do nothing because memory usage
        # is too high.
-       self.mem_cutoff = CutoffEstimator(0.1)
+       self.mem_cutoff = CutoffEstimator(0.2)
        if isinstance(prob, float):
            assert 0 < prob <= 1
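
Both ActivationBalancer and Whiten use the same memory guard: a CutoffEstimator tracks allocated GPU memory across forward() calls, and the module skips its work when the current allocation is in the highest-memory fraction of recent calls, now 20% instead of 10%. How the gate is consulted is sketched below as an assumption (the method name maybe_skip_for_memory is made up; torch.cuda.memory_allocated() is the real PyTorch call):

import torch


def maybe_skip_for_memory(self, x: torch.Tensor) -> bool:
    # Only relevant when training on CUDA; returns True when the current
    # allocation is above the estimator's running cutoff, i.e. in the
    # highest-memory ~20% of recent forward() calls.
    if not (x.requires_grad and x.is_cuda):
        return False
    return self.mem_cutoff(torch.cuda.memory_allocated())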


@@ -1799,7 +1799,10 @@ class Conv2dSubsampling(nn.Module):
        )
        self.convnext1 = nn.Sequential(ConvNeXt(layer2_channels),
-                                       ConvNeXt(layer2_channels))
+                                       ConvNeXt(layer2_channels),
+                                       BasicNorm(layer2_channels,
+                                                 channel_dim=1))
        self.conv2 = nn.Sequential(
            nn.Conv2d(
@@ -1815,7 +1818,9 @@ class Conv2dSubsampling(nn.Module):
        )
        self.convnext2 = nn.Sequential(ConvNeXt(layer3_channels),
-                                       ConvNeXt(layer3_channels))
+                                       ConvNeXt(layer3_channels),
+                                       BasicNorm(layer3_channels,
+                                                 channel_dim=1))
        out_height = (((in_channels - 1) // 2) - 1) // 2
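
BasicNorm with channel_dim=1 rescales each position of the (N, C, H, W) activations by the inverse root-mean-square over the channel axis, with a learned epsilon added to the mean-square, so the output of each ConvNeXt stack is kept at a bounded scale before the next stage. A sketch of that idea (see the real BasicNorm module for the exact definition; this version omits options such as making the epsilon non-learnable):

import torch
import torch.nn as nn


class BasicNormSketch(nn.Module):
    """Divides by the RMS over `channel_dim`, with a learnable epsilon stored
    in log space added to the mean-square.  Illustrative sketch only."""

    def __init__(self, num_channels: int, channel_dim: int = -1, eps: float = 0.25):
        super().__init__()
        self.num_channels = num_channels  # kept only for the shape check below
        self.channel_dim = channel_dim
        self.eps = nn.Parameter(torch.tensor(eps).log())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.shape[self.channel_dim] == self.num_channels
        scales = (
            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + self.eps.exp()
        ) ** -0.5
        return x * scales


# e.g. BasicNormSketch(layer2_channels, channel_dim=1) applied to a
# (batch, layer2_channels, height, width) tensor normalizes over dim 1.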