diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 61171be00..8a9867103 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -235,8 +235,8 @@ class CutoffEstimator:
     """
     Estimates cutoffs of an arbitrary numerical quantity such that a specified
     proportion of items will be above the cutoff on average.
-    p is the proportion of items that should be above the cutoff.
+      p is the proportion of items that should be above the cutoff.
     """
     def __init__(self, p: float):
         self.p = p
@@ -614,11 +614,10 @@ class ActivationBalancer(torch.nn.Module):
       max_abs:  the maximum average-absolute-value difference from the mean value
            per channel, which we allow, before we start to modify the derivatives
            to prevent this.
-      min_prob: determines the minimum probability with which we modify the
+      prob: determines the minimum probability with which we modify the
            gradients for the {min,max}_positive and {min,max}_abs constraints,
            on each forward().  This is done randomly to prevent all layers
-           from doing it at the same time.  Early in training we may use
-           higher probabilities than this; it will decay to this value.
+           from doing it at the same time.
    """
    def __init__(
        self,
@@ -637,11 +636,11 @@ class ActivationBalancer(torch.nn.Module):
        if prob is None:
-            prob = ScheduledFloat((0.0, 0.4), (8000.0, 0.1), default=0.4)
+            prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4)
        self.prob = prob
-        # 10% of the time we will return and do nothing because memory usage
-        # is too high.
-        self.mem_cutoff = CutoffEstimator(0.1)
+        # 20% of the time we will return and do nothing because memory usage is
+        # too high.
+        self.mem_cutoff = CutoffEstimator(0.2)

        # actually self.num_channels is no longer needed except for an assertion.
        self.num_channels = num_channels
@@ -838,9 +837,9 @@ class Whiten(nn.Module):
        self.whitening_limit = whitening_limit
        self.grad_scale = grad_scale

-        # 10% of the time we will return and do nothing because memory usage
+        # 20% of the time we will return and do nothing because memory usage
        # is too high.
-        self.mem_cutoff = CutoffEstimator(0.1)
+        self.mem_cutoff = CutoffEstimator(0.2)

        if isinstance(prob, float):
            assert 0 < prob <= 1
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 197362c9d..0d6497326 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1799,7 +1799,10 @@ class Conv2dSubsampling(nn.Module):
        )
        self.convnext1 = nn.Sequential(ConvNeXt(layer2_channels),
-                                       ConvNeXt(layer2_channels))
+                                       ConvNeXt(layer2_channels),
+                                       BasicNorm(layer2_channels,
+                                                 channel_dim=1))
+
        self.conv2 = nn.Sequential(
            nn.Conv2d(
@@ -1815,7 +1818,9 @@ class Conv2dSubsampling(nn.Module):
        )
        self.convnext2 = nn.Sequential(ConvNeXt(layer3_channels),
-                                       ConvNeXt(layer3_channels))
+                                       ConvNeXt(layer3_channels),
+                                       BasicNorm(layer3_channels,
+                                                 channel_dim=1))

        out_height = (((in_channels - 1) // 2) - 1) // 2
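
Note on the CutoffEstimator change: p is the proportion of observed values expected to land above the estimated cutoff, so moving from CutoffEstimator(0.1) to CutoffEstimator(0.2) lets ActivationBalancer and Whiten skip their gradient-modification work roughly twice as often when memory usage is high. The implementation of CutoffEstimator is not part of this diff; the sketch below is only a hypothetical illustration of that contract (ToyCutoffEstimator and its step parameter are invented here), using a simple proportional feedback update.

# Hypothetical sketch (not the icefall implementation): an adaptive cutoff
# estimator with the same contract as CutoffEstimator(p) above, i.e. the
# cutoff adapts so that a proportion `p` of observed values falls above it.
import random


class ToyCutoffEstimator:
    def __init__(self, p: float, step: float = 0.02):
        self.p = p          # target proportion of values above the cutoff
        self.cutoff = 0.0   # running cutoff estimate
        self.step = step    # adaptation rate (arbitrary choice for this sketch)

    def __call__(self, x: float) -> bool:
        """Return True if x exceeds the current cutoff, then adapt the cutoff."""
        above = x > self.cutoff
        # Asymmetric update: in steady state the upward moves (fraction ~p,
        # size step) balance the downward moves (fraction ~1-p, size
        # step * p / (1-p)), so about p of future values end up above.
        if above:
            self.cutoff += self.step
        else:
            self.cutoff -= self.step * self.p / (1.0 - self.p)
        return above


# With p=0.2, roughly 20% of calls return True once the estimate settles,
# which is how often ActivationBalancer/Whiten would skip work under this diff.
est = ToyCutoffEstimator(p=0.2)
hits = sum(est(random.random()) for _ in range(10_000))
print(hits / 10_000)  # ~0.2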