Add BasicNorm to ConvNeXt; increase prob given to CutoffEstimator; adjust default probs of ActivationBalancer.
commit 5e1bf8b8ec
parent a424a73881
@@ -235,8 +235,8 @@ class CutoffEstimator:
     """
     Estimates cutoffs of an arbitrary numerical quantity such that a specified
     proportion of items will be above the cutoff on average.
-    p is the proportion of items that should be above the cutoff.
+
+    p is the proportion of items that should be above the cutoff.
     """
     def __init__(self, p: float):
         self.p = p
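Note: the docstring above describes an adaptive cutoff that ends up with a proportion p of observed values above it. A minimal sketch of one way such an estimator could behave, using a hypothetical feedback rule rather than icefall's actual implementation:

class SimpleCutoffEstimator:
    """Hypothetical illustration only; not icefall's CutoffEstimator."""

    def __init__(self, p: float, step: float = 0.05):
        self.p = p          # target proportion of values above the cutoff
        self.step = step    # adaptation speed
        self.cutoff = 0.0

    def __call__(self, x: float) -> bool:
        above = x > self.cutoff
        # Raise the cutoff a little when a value lands above it, lower it a
        # little otherwise; the asymmetric step sizes balance out exactly when
        # a fraction p of values exceeds the cutoff.
        if above:
            self.cutoff += self.step * (1.0 - self.p)
        else:
            self.cutoff -= self.step * self.p
        return above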
@@ -614,11 +614,10 @@ class ActivationBalancer(torch.nn.Module):
        max_abs: the maximum average-absolute-value difference from the mean
           value per channel, which we allow, before we start to modify
           the derivatives to prevent this.
-       min_prob: determines the minimum probability with which we modify the
+       prob: determines the minimum probability with which we modify the
           gradients for the {min,max}_positive and {min,max}_abs constraints,
           on each forward(). This is done randomly to prevent all layers
-          from doing it at the same time. Early in training we may use
-          higher probabilities than this; it will decay to this value.
+          from doing it at the same time.
     """
     def __init__(
         self,
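The "on each forward(), with probability prob" wording above amounts to a per-call random gate; roughly (illustrative only, not icefall's code):

import random

def apply_constraints_this_forward(prob: float) -> bool:
    # Each layer draws independently, so with prob well below 1.0 it is
    # unlikely that every layer applies its constraints on the same batch.
    return random.random() < prob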
@@ -637,11 +636,11 @@ class ActivationBalancer(torch.nn.Module):


         if prob is None:
-            prob = ScheduledFloat((0.0, 0.4), (8000.0, 0.1), default=0.4)
+            prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4)
         self.prob = prob
-        # 10% of the time we will return and do nothing because memory usage
-        # is too high.
-        self.mem_cutoff = CutoffEstimator(0.1)
+        # 20% of the time we will return and do nothing because memory usage is
+        # too high.
+        self.mem_cutoff = CutoffEstimator(0.2)

         # actually self.num_channels is no longer needed except for an assertion.
         self.num_channels = num_channels
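The (x, y) pairs passed to ScheduledFloat read as (batch_count, value) points. Assuming linear interpolation between them and clamping outside the range (an assumption about ScheduledFloat, not something shown in this diff), the new default schedule for prob behaves roughly like:

def prob_schedule(batch_count: float) -> float:
    # New default: 0.5 at batch 0, decaying linearly to 0.125 by batch 8000,
    # then staying there (previously 0.4 decaying to 0.1).
    (x0, y0), (x1, y1) = (0.0, 0.5), (8000.0, 0.125)
    if batch_count >= x1:
        return y1
    frac = (batch_count - x0) / (x1 - x0)
    return y0 + frac * (y1 - y0)

# prob_schedule(0)    -> 0.5
# prob_schedule(4000) -> 0.3125
# prob_schedule(8000) -> 0.125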
@@ -838,9 +837,9 @@ class Whiten(nn.Module):
         self.whitening_limit = whitening_limit
         self.grad_scale = grad_scale

-        # 10% of the time we will return and do nothing because memory usage
+        # 20% of the time we will return and do nothing because memory usage
         # is too high.
-        self.mem_cutoff = CutoffEstimator(0.1)
+        self.mem_cutoff = CutoffEstimator(0.2)

         if isinstance(prob, float):
             assert 0 < prob <= 1
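Both changes above raise the CutoffEstimator argument from 0.1 to 0.2, i.e. these modules now skip their extra work on roughly 20% of calls instead of 10%, targeting the most memory-hungry ones. A hypothetical sketch of how such a gate might be consulted in forward() (names and control flow are assumptions, not taken from icefall):

import torch

def should_skip_for_memory(x: torch.Tensor, mem_cutoff) -> bool:
    # mem_cutoff is assumed to be callable, returning True for values above
    # its adaptive cutoff; with CutoffEstimator(0.2) that is ~20% of calls.
    if x.requires_grad and torch.cuda.is_available():
        return bool(mem_cutoff(torch.cuda.memory_allocated()))
    return False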
@@ -1799,7 +1799,10 @@ class Conv2dSubsampling(nn.Module):
         )

         self.convnext1 = nn.Sequential(ConvNeXt(layer2_channels),
-                                        ConvNeXt(layer2_channels))
+                                        ConvNeXt(layer2_channels),
+                                        BasicNorm(layer2_channels,
+                                                  channel_dim=1))
+

         self.conv2 = nn.Sequential(
             nn.Conv2d(
@@ -1815,7 +1818,9 @@ class Conv2dSubsampling(nn.Module):
         )

         self.convnext2 = nn.Sequential(ConvNeXt(layer3_channels),
-                                        ConvNeXt(layer3_channels))
+                                        ConvNeXt(layer3_channels),
+                                        BasicNorm(layer3_channels,
+                                                  channel_dim=1))


         out_height = (((in_channels - 1) // 2) - 1) // 2
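In both hunks above, convnext1 and convnext2 gain a BasicNorm over the channel dimension (channel_dim=1 for NCHW feature maps) after the pair of ConvNeXt blocks. BasicNorm and ConvNeXt are icefall classes not shown in this diff; as a rough stand-in for intuition, a channel-wise RMS-style normalization with a fixed epsilon could look like this (treat the details as assumptions):

import torch
import torch.nn as nn

class ChannelRMSNorm(nn.Module):
    """Rough stand-in for intuition only; not icefall's BasicNorm."""

    def __init__(self, num_channels: int, channel_dim: int = 1, eps: float = 0.25):
        super().__init__()
        self.num_channels = num_channels  # kept only to mirror the call signature
        self.channel_dim = channel_dim
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize each position by the RMS of its channel vector, keeping the
        # scale of the ConvNeXt output controlled before the next convolution.
        rms = (x.pow(2).mean(dim=self.channel_dim, keepdim=True) + self.eps).sqrt()
        return x / rms

# Wired in the same way the diff wires in BasicNorm (ConvNeXt stands for the
# icefall module of that name):
#   self.convnext1 = nn.Sequential(ConvNeXt(layer2_channels),
#                                  ConvNeXt(layer2_channels),
#                                  ChannelRMSNorm(layer2_channels, channel_dim=1))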