Have warmup schedule for layer-skipping
This commit is contained in:
parent 072776b2a1
commit a3561c8dcd
@@ -84,6 +84,15 @@ class Zipformer(EncoderInterface):
         self.zipformer_downsampling_factors = zipformer_downsampling_factors
         self.output_downsampling_factor = output_downsampling_factor
 
+        # keep track of how many times forward() has been called, for purposes
+        # of warmup. do this with a floating-point count because integer counts
+        # can fail to survive model averaging. initialize with a smallish
+        # random number so that different encoders use different random seeds in
+        # shared_rng get_layers_to_drop() while using the same random seeds
+        # across jobs.
+        self.register_buffer('warmup_count', torch.tensor(float(10.0 * random.random())))
+        self.warmup_end = warmup_batches
+
         for u,d in zip(encoder_unmasked_dims, encoder_dims):
             assert u <= d
 
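A small illustration (not part of the commit) of why the new warmup_count buffer is floating-point rather than integer: assuming model averaging means elementwise averaging of checkpoint tensors, averaging two float counts still yields a usable count, whereas an integer buffer would need truncation or an explicit cast. The values below are hypothetical.

import torch

# Elementwise averaging of two checkpoints' float warmup counts still
# produces a meaningful in-between count.
count_a = torch.tensor(4999.25)
count_b = torch.tensor(5001.75)
print((count_a + count_b) / 2)  # tensor(5000.5000)

# Averaging integer tensors with true division yields a float tensor, which
# could not be stored back into an integer buffer without an explicit cast;
# this is the failure mode the comment above alludes to.
count_int_a = torch.tensor(4999)
count_int_b = torch.tensor(5001)
print((count_int_a + count_int_b) / 2)  # tensor(5000.), dtype is now float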
@@ -138,6 +147,35 @@ class Zipformer(EncoderInterface):
             encoder_dims[-1],
             downsample=output_downsampling_factor)
 
+
+    def get_warmup_count(self) -> float:
+        """
+        Returns a value that reflects how many times this function has been called in training mode.
+        """
+        ans = self.warmup_count.item()
+        if self.training:
+            if ans > 1000000.0:
+                # this ensures that as the number of batches gets large, the warmup count cycles rather
+                # than getting stuck at the smallest floating point value x such that x + 1 == x.
+                # this is necessary because get_layers_to_drop() relies on the warmup count changing
+                # on every batch.
+                next_count = 500000.0
+            else:
+                next_count = ans + 1.0
+            self.warmup_count.fill_(next_count)
+        return ans
+
+    def _get_layer_skip_dropout_prob(self):
+        if not self.training:
+            return 0.0
+        warmup_count = self.get_warmup_count()
+        min_dropout_prob = 0.025
+
+        if warmup_count > self.warmup_end:
+            return min_dropout_prob
+        else:
+            return 0.5 - (warmup_count / self.warmup_end) * (0.5 - min_dropout_prob)
+
     def _init_skip_modules(self):
         """
         If self.zipformer_downsampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer
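As a rough illustration (not part of the diff), the schedule implemented by _get_layer_skip_dropout_prob() above decays the layer-skip dropout probability linearly from 0.5 at the start of training to a floor of 0.025 once the warmup count reaches self.warmup_end. The standalone sketch below reproduces that formula with a hypothetical warmup_end of 1000 batches.

def layer_skip_dropout_prob(warmup_count: float, warmup_end: float) -> float:
    # same formula as _get_layer_skip_dropout_prob() above, without the
    # module state; warmup_end stands in for self.warmup_end (warmup_batches)
    min_dropout_prob = 0.025
    if warmup_count > warmup_end:
        return min_dropout_prob
    return 0.5 - (warmup_count / warmup_end) * (0.5 - min_dropout_prob)

for count in (0.0, 250.0, 500.0, 1000.0, 5000.0):
    print(count, layer_skip_dropout_prob(count, warmup_end=1000.0))
# 0.0 -> 0.5, 250.0 -> 0.38125, 500.0 -> 0.2625, 1000.0 -> 0.025, 5000.0 -> 0.025

So skip connections into each encoder stack are dropped roughly half the time early in training and almost never after warmup, which is the "warmup schedule for layer-skipping" named in the commit title.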
@@ -256,7 +294,7 @@ class Zipformer(EncoderInterface):
         for i, module in enumerate(self.encoders):
             ds = self.zipformer_downsampling_factors[i]
             if self.skip_layers[i] is not None:
-                layer_skip_dropout_prob = 0.05
+                layer_skip_dropout_prob = self._get_layer_skip_dropout_prob()
                 if (not self.training) or random.random() > layer_skip_dropout_prob:
                     x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
             x = module(x,