Try to refactor the code for scheduling

Daniel Povey 2022-11-14 12:50:24 +08:00
parent aa0b1a37cd
commit cd4730b657
2 changed files with 96 additions and 131 deletions


@@ -981,6 +981,52 @@ class DoubleSwish(torch.nn.Module):
return DoubleSwishFunction.apply(x)
class ScheduledFloat(torch.nn.Module):
"""
This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
it does not have a working forward() function. You are supposed to cast it to float, as
in, float(parent_module.whatever), and use it as something like a dropout prob.
It is a floating point value whose value changes depending on the batch count of the
training loop. It is a piecewise linear function where you specify the (x,y) pairs
in sorted order on x; x corresponds to the batch index. For batch-index values before the
first x or after the last x, we just use the first or last y value.
Example:
self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0))
"""
def __init__(self,
*args):
super().__init__()
# self.batch_count will be written to in the training loop.
self.batch_count = 0
assert len(args) >= 1
for (x,y) in args:
assert x >= 0
for i in range(len(args) - 1):
assert args[i + 1] > args[i], args
self.schedule = args
def extra_repr(self) -> str:
return 'batch_count={}, schedule={}'.format(self.batch_count,
self.schedule)
def __float__(self):
batch_count = self.batch_count
if batch_count <= self.schedule[0][0]:
return self.schedule[0][1]
elif batch_count >= self.schedule[-1][0]:
return self.schedule[-1][1]
else:
# find the segment [cur_x, next_x] that contains batch_count and interpolate linearly.
for i in range(1, len(self.schedule)):
cur_x, cur_y = self.schedule[i - 1]
next_x, next_y = self.schedule[i]
if batch_count >= cur_x and batch_count <= next_x:
return cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
assert False
FloatLike = Union[float, ScheduledFloat]
def _test_max_eig():
for proportion in [0.1, 0.5, 10.0]:
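Since ScheduledFloat is the centerpiece of this refactoring, here is a minimal standalone usage sketch (assuming scaling.ScheduledFloat is importable; the training loop is what normally updates batch_count, here it is set by hand, and the schedule values are only for illustration):

from scaling import ScheduledFloat

dropout_schedule = ScheduledFloat((0.0, 0.2), (4000.0, 0.0))

for bc in (0, 2000, 4000, 100000):
    # normally written by the training loop; set by hand for this sketch
    dropout_schedule.batch_count = bc
    print(bc, float(dropout_schedule))
# -> 0.2 before the first x, ~0.1 at the midpoint (linear interpolation),
#    0.0 at and after the last x.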


@@ -37,6 +37,8 @@ from scaling import (
random_clamp,
penalize_abs_values_gt,
softmax,
ScheduledFloat,
FloatLike,
)
from torch import Tensor, nn
@@ -104,6 +106,13 @@ class Zipformer(EncoderInterface):
) -> None:
super(Zipformer, self).__init__()
# this is not the probability of skipping a layer. It is the probability of
# dropping out the "skip module" which allows the model to skip groups of
# encoder stacks; when it's dropped out like this, it means we are forced
# to take the "direct" (non-bypass) path.
self.layer_skip_dropout_prob = ScheduledFloat((0.0, 0.5),
(warmup_batches, 0.025))
def _to_tuple(x):
""" Converts a single int or a 1-tuple of an int to a tuple with the same length
as downsampling_factor"""
@@ -128,9 +137,6 @@ class Zipformer(EncoderInterface):
feedforward_dim = _to_tuple(feedforward_dim)
cnn_module_kernel = _to_tuple(cnn_module_kernel)
# will be written to in training loop, see set_batch_count()
self.batch_count = 0
self.warmup_end = warmup_batches
for u,d in zip(encoder_unmasked_dim, encoder_dim):
assert u <= d
@@ -193,17 +199,6 @@ class Zipformer(EncoderInterface):
downsample=output_downsampling_factor)
def _get_layer_skip_dropout_prob(self):
if not self.training:
return 0.0
batch_count = self.batch_count
min_dropout_prob = 0.025
if batch_count > self.warmup_end:
return min_dropout_prob
else:
return 0.5 - (batch_count / self.warmup_end) * (0.5 - min_dropout_prob)
def _init_skip_modules(self):
"""
If self.downsampling_factor = (1, 2, 4, 8, 4, 2), then at the input of layer
@@ -321,8 +316,10 @@ class Zipformer(EncoderInterface):
for i, module in enumerate(self.encoders):
ds = self.downsampling_factor[i]
if self.skip_layers[i] is not None:
layer_skip_dropout_prob = self._get_layer_skip_dropout_prob()
if (not self.training) or random.random() > layer_skip_dropout_prob:
# this is how we implement U-net-like skipping of some series of
# stacks. The layer_skip_dropout_prob is there to discourage it, especially
# early in training, from completely ignoring the middle layers.
if not (self.training and random.random() < float(self.layer_skip_dropout_prob)):
x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
x = module(x,
feature_mask=feature_masks[i],
@@ -365,12 +362,24 @@ class ZipformerEncoderLayer(nn.Module):
feedforward_dim: int,
dropout: float = 0.1,
cnn_module_kernel: int = 31,
# layer_skip_prob will be overwritten (see ZipformerEncoder.__init__) to set per-layer warmup begin and end times.
layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.5), (2000.0, 0.05)),
dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.2), (2000.0, 0.0)),
bypass_clamp_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25)),
bypass_clamp_max: FloatLike = 1.0,
) -> None:
super(ZipformerEncoderLayer, self).__init__()
self.embed_dim = embed_dim
# will be written to in training loop, see set_batch_count()
self.batch_count = 0
# probability of skipping the entire layer.
self.layer_skip_prob = copy.deepcopy(layer_skip_prob)
# skip probability for dynamic modules (meaning: anything but feedforward)
self.dynamic_skip_prob = copy.deepcopy(dynamic_skip_prob)
# min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
# ever becoming zero.
self.bypass_clamp_min = copy.deepcopy(bypass_clamp_min)
self.bypass_clamp_max = copy.deepcopy(bypass_clamp_max)
self.self_attn_weights = RelPositionMultiheadAttentionWeights(
embed_dim, pos_dim=pos_dim, num_heads=num_heads,
@@ -424,36 +433,13 @@ class ZipformerEncoderLayer(nn.Module):
grad_scale=0.01)
def get_bypass_scale(self):
if torch.jit.is_scripting() or not self.training:
if torch.jit.is_scripting() or not self.training or random.random() < 0.5:
# the random.random() part is to ensure we get grads if self.bypass_scale becomes out of range
return self.bypass_scale
if random.random() < 0.1:
# ensure we get grads if self.bypass_scale becomes out of range
return self.bypass_scale
# hardcode warmup period for bypass scale
warmup_period = 20000.0
initial_clamp_min = 0.75
final_clamp_min = 0.25
if self.batch_count > warmup_period:
clamp_min = final_clamp_min
else:
clamp_min = (initial_clamp_min -
(self.batch_count / warmup_period) * (initial_clamp_min - final_clamp_min))
return self.bypass_scale.clamp(min=clamp_min, max=1.0)
def get_dynamic_dropout_rate(self):
# return dropout rate for the dynamic modules (self_attn, pooling, convolution); this
# starts at 0.2 and rapidly decreases to 0. Its purpose is to keep the training stable
# at the beginning, by making the network focus on the feedforward modules.
if torch.jit.is_scripting() or not self.training:
return 0.0
warmup_period = 2000.0
initial_dropout_rate = 0.2
final_dropout_rate = 0.0
if self.batch_count > warmup_period:
return final_dropout_rate
else:
return (initial_dropout_rate -
(initial_dropout_rate * final_dropout_rate) * (self.batch_count / warmup_period))
return self.bypass_scale.clamp(min=float(self.bypass_clamp_min),
max=float(self.bypass_clamp_max))
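The comment above says the clamp on bypass_scale is applied only with probability 0.5 to keep its gradients from ever becoming zero; a toy sketch (illustrative values, not code from this commit) shows why a hard clamp alone would stall the parameter once it drifts out of range:

import torch

# pretend bypass_scale has drifted below an assumed clamp_min of 0.25
scale = torch.tensor([0.1], requires_grad=True)

clamped = scale.clamp(min=0.25, max=1.0)
clamped.sum().backward()
print(scale.grad)        # tensor([0.]) -- the clamp blocks the gradient

scale.grad = None
scale.sum().backward()   # the unclamped branch, taken with probability 0.5
print(scale.grad)        # tensor([1.]) -- gradient still reaches the parameter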
def forward(
self,
@@ -479,18 +465,20 @@ class ZipformerEncoderLayer(nn.Module):
src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
"""
if self.training and random.random() < float(self.layer_skip_prob):
# skip the layer
return src
src_orig = src
# macaron style feed forward module
src = src + self.feed_forward1(src)
# dropout rate for submodules that interact with time.
dynamic_dropout = self.get_dynamic_dropout_rate()
# skip probability for non-feedforward submodules (self-attention, pooling, convolution)
dynamic_skip_prob = float(self.dynamic_skip_prob) if self.training else 0.0
# multi-headed self-attention module
# TODO: make the various attention-using modules be dropped
# out independently.
use_self_attn = (random.random() > dynamic_dropout)
use_self_attn = (random.random() >= dynamic_skip_prob)
if torch.jit.is_scripting() or use_self_attn:
# attn_weights: (num_heads, batch_size, seq_len, seq_len)
attn_weights = self.self_attn_weights(
@@ -519,7 +507,7 @@ class ZipformerEncoderLayer(nn.Module):
src = src + self.self_attn2(
src, attn_weights)
if torch.jit.is_scripting() or random.random() > dynamic_dropout:
if torch.jit.is_scripting() or random.random() >= dynamic_skip_prob:
src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
src = src + self.feed_forward3(src)
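The default dynamic_skip_prob = ScheduledFloat((0.0, 0.2), (2000.0, 0.0)) matches the schedule the deleted get_dynamic_dropout_rate() described (0.2 decaying to 0.0 over the first 2000 batches); a quick check with hand-set batch counts, along the lines of the earlier sketch:

from scaling import ScheduledFloat

dynamic_skip = ScheduledFloat((0.0, 0.2), (2000.0, 0.0))
for bc in (0, 1000, 2000, 50000):
    dynamic_skip.batch_count = bc
    print(bc, float(dynamic_skip))
# -> 0.2, ~0.1, 0.0, 0.0: early in training the attention/pooling/convolution
#    branches are sometimes skipped, so the network initially leans on the
#    feedforward modules, as the old comment explains.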
@@ -559,19 +547,15 @@ class ZipformerEncoder(nn.Module):
pos_dim: int,
dropout: float,
warmup_begin: float,
warmup_end: float
warmup_end: float,
initial_layerdrop_prob: float = 0.5,
final_layerdrop_prob: float = 0.05,
) -> None:
super().__init__()
# will be written to, see set_batch_count().  Note: at inference time this
# may be zero but should be treated as large; we can check whether
# self.training is true.
self.batch_count = 0
self.warmup_begin = warmup_begin
self.warmup_end = warmup_end
# module_seed is for when we need a random number that is unique to the module but
# shared across jobs. It's used to randomly select how many layers to drop,
# so that we can keep this consistent across worker tasks (for efficiency).
self.module_seed = torch.randint(0, 1000, ()).item()
self.encoder_pos = RelPositionalEncoding(pos_dim, dropout_rate=0.15)
@@ -582,75 +566,13 @@ class ZipformerEncoder(nn.Module):
assert 0 <= warmup_begin <= warmup_end
delta = (1. / num_layers) * (warmup_end - warmup_begin)
cur_begin = warmup_begin
cur_begin = warmup_begin # interpreted as a training batch index
for i in range(num_layers):
self.layers[i].warmup_begin = cur_begin
cur_begin += delta
self.layers[i].warmup_end = cur_begin
def get_layers_to_drop(self, rnd_seed: int):
ans = set()
if not self.training:
return ans
batch_count = self.batch_count
num_layers = len(self.layers)
def get_layerdrop_prob(layer: int) -> float:
layer_warmup_begin = self.layers[layer].warmup_begin
layer_warmup_end = self.layers[layer].warmup_end
initial_layerdrop_prob = 0.5
final_layerdrop_prob = 0.05
if batch_count == 0:
# As a special case, if batch_count == 0, return 0 (drop no
# layers). This is rather ugly, I'm afraid; it is intended to
# enable our scan_pessimistic_batches_for_oom() code to work correctly,
# so that if we are going to get OOM it will happen early.
# Also, search for 'batch_count' with quotes in this file to see
# how we initialize the warmup count to a random number between
# 0 and 10.
return 0.0
elif batch_count < layer_warmup_begin:
return initial_layerdrop_prob
elif batch_count > layer_warmup_end:
return final_layerdrop_prob
else:
# linearly interpolate
t = (batch_count - layer_warmup_begin) / layer_warmup_end
assert 0.0 <= t < 1.001
return initial_layerdrop_prob + t * (final_layerdrop_prob - initial_layerdrop_prob)
shared_rng = random.Random(batch_count + self.module_seed)
independent_rng = random.Random(rnd_seed)
layerdrop_probs = [ get_layerdrop_prob(i) for i in range(num_layers) ]
tot = sum(layerdrop_probs)
# Instead of drawing the samples independently, we first randomly decide
# how many layers to drop out, using the same random number generator between
# jobs so that all jobs drop out the same number (this is for speed).
# Then we use an approximate approach to drop out the individual layers
# with their specified probs while reaching this exact target.
num_to_drop = int(tot) + int(shared_rng.random() < (tot - int(tot)))
layers = list(range(num_layers))
independent_rng.shuffle(layers)
# go through the shuffled layers until we get the required number of samples.
if num_to_drop > 0:
for layer in itertools.cycle(layers):
if independent_rng.random() < layerdrop_probs[layer]:
ans.add(layer)
if len(ans) == num_to_drop:
break
if shared_rng.random() < 0.005 or __name__ == "__main__":
logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, "
f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}")
return ans
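The block deleted above implements the sampling scheme its comments describe: sum the per-layer drop probabilities to get an expected count, round that count stochastically with a shared RNG so every worker drops the same number of layers, then pick which layers with an independent RNG. A standalone sketch of the same idea (hypothetical helper, not code from this commit):

import itertools
import random

def sample_layers_to_drop(layerdrop_probs, shared_rng, independent_rng):
    # expected number of layers to drop is the sum of the per-layer probabilities
    tot = sum(layerdrop_probs)
    # shared RNG: stochastic rounding of the expected count, identical across workers
    num_to_drop = int(tot) + int(shared_rng.random() < (tot - int(tot)))
    ans = set()
    if num_to_drop == 0:
        return ans
    # independent RNG: choose which layers, approximately respecting the
    # per-layer probabilities, until the target count is reached
    layers = list(range(len(layerdrop_probs)))
    independent_rng.shuffle(layers)
    for layer in itertools.cycle(layers):
        if independent_rng.random() < layerdrop_probs[layer]:
            ans.add(layer)
            if len(ans) == num_to_drop:
                break
    return ans

# with probabilities summing to 1.35, either 1 or 2 layers get dropped, never 0 or all 4
print(sample_layers_to_drop([0.5, 0.05, 0.3, 0.5], random.Random(0), random.Random(1)))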
cur_end = cur_begin + delta
self.layers[i].layer_skip_prob = ScheduledFloat((cur_begin, initial_layerdrop_prob),
(cur_end, final_layerdrop_prob))
cur_begin = cur_end
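For concreteness, the replacement loop above staggers the per-layer schedules; a worked example with assumed values (num_layers=4, warmup_begin=0.0, warmup_end=2000.0, and the default initial/final layerdrop probabilities of 0.5 and 0.05):

num_layers, warmup_begin, warmup_end = 4, 0.0, 2000.0
initial_layerdrop_prob, final_layerdrop_prob = 0.5, 0.05

delta = (1.0 / num_layers) * (warmup_end - warmup_begin)   # 500 batches per layer
cur_begin = warmup_begin
for i in range(num_layers):
    cur_end = cur_begin + delta
    print(f"layer {i}: ScheduledFloat(({cur_begin}, {initial_layerdrop_prob}), "
          f"({cur_end}, {final_layerdrop_prob}))")
    cur_begin = cur_end
# layer 0: ScheduledFloat((0.0, 0.5), (500.0, 0.05))
# layer 1: ScheduledFloat((500.0, 0.5), (1000.0, 0.05))
# layer 2: ScheduledFloat((1000.0, 0.5), (1500.0, 0.05))
# layer 3: ScheduledFloat((1500.0, 0.5), (2000.0, 0.05))
# each layer's skip probability ramps from 0.5 down to 0.05 over its own
# window, with earlier layers finishing their warmup first.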
def forward(
@@ -683,13 +605,10 @@ class ZipformerEncoder(nn.Module):
rnd_seed = src.numel() + random.randint(0, 1000)
layers_to_drop = self.get_layers_to_drop(rnd_seed)
output = output * feature_mask
for i, mod in enumerate(self.layers):
if i in layers_to_drop:
continue
output = mod(
output,
pos_emb,
@@ -942,7 +861,7 @@ class SimpleCombiner(torch.nn.Module):
weight1 = self.weight1
if self.training and random.random() < 0.25 and self.min_weight != (0., 0.):
if self.training and random.random() < 0.5 and self.min_weight != (0., 0.):
weight1 = weight1.clamp(min=self.min_weight[0],
max=1.0-self.min_weight[1])