Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)

Try to refactor the code for scheduling

This commit is contained in:
parent  aa0b1a37cd
commit  cd4730b657
@@ -981,6 +981,52 @@ class DoubleSwish(torch.nn.Module):
        return DoubleSwishFunction.apply(x)


class ScheduledFloat(torch.nn.Module):
    """
    This object is a torch.nn.Module only because we want it to show up in
    [top_level module].modules(); it does not have a working forward() function.
    You are supposed to cast it to float, as in float(parent_module.whatever),
    and use it as something like a dropout prob.

    It is a floating point value whose value changes depending on the batch count
    of the training loop.  It is a piecewise linear function where you specify the
    (x, y) pairs in sorted order on x; x corresponds to the batch index.  For
    batch-index values before the first x or after the last x, we just use the
    first or last y value.

    Example:
        self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0))
    """
    def __init__(self,
                 *args):
        super().__init__()
        # self.batch_count will be written to in the training loop.
        self.batch_count = 0
        assert len(args) >= 1
        for (x, y) in args:
            assert x >= 0
        for i in range(len(args) - 1):
            assert args[i + 1] > args[i], args
        self.schedule = args

    def extra_repr(self) -> str:
        return 'batch_count={}, schedule={}'.format(self.batch_count,
                                                    self.schedule)

    def __float__(self):
        batch_count = self.batch_count
        if batch_count <= self.schedule[0][0]:
            return self.schedule[0][1]
        elif batch_count >= self.schedule[-1][0]:
            return self.schedule[-1][1]
        else:
            cur_x, cur_y = self.schedule[0]
            for i in range(1, len(self.schedule)):
                next_x, next_y = self.schedule[i]
                if batch_count >= cur_x and batch_count <= next_x:
                    return cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
                # move on to the next segment so that schedules with more than
                # two points interpolate within the right interval
                cur_x, cur_y = next_x, next_y
            assert False


FloatLike = Union[float, ScheduledFloat]


def _test_max_eig():
    for proportion in [0.1, 0.5, 10.0]:
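A minimal sketch of how a training loop can drive every ScheduledFloat at once; the set_batch_count helper and ToyLayer below are illustrative assumptions, not part of this diff. Because ScheduledFloat is an nn.Module, it appears in model.modules(), so writing batch_count there is enough, and the scheduled value is read back with float(...).

import torch
import torch.nn as nn

from scaling import ScheduledFloat

def set_batch_count(model: nn.Module, batch_count: float) -> None:
    # write the current batch index into every submodule that keeps one,
    # which includes every ScheduledFloat
    for module in model.modules():
        if hasattr(module, "batch_count"):
            module.batch_count = batch_count

class ToyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.dropout_p = ScheduledFloat((0.0, 0.2), (4000.0, 0.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.dropout(x, p=float(self.dropout_p), training=self.training)

model = ToyLayer()
set_batch_count(model, 2000)
print(float(model.dropout_p))  # 0.1: halfway along the (0.0, 0.2) -> (4000.0, 0.0) ramp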
@@ -37,6 +37,8 @@ from scaling import (
    random_clamp,
    penalize_abs_values_gt,
    softmax,
    ScheduledFloat,
    FloatLike,
)
from torch import Tensor, nn
@@ -104,6 +106,13 @@ class Zipformer(EncoderInterface):
    ) -> None:
        super(Zipformer, self).__init__()

        # This is not the probability of skipping a layer.  It is the probability of
        # dropping out the "skip module" which allows the model to skip groups of
        # encoder stacks; when it's dropped out like this, it means we are forced
        # to take the "direct" (non-bypass) path.
        self.layer_skip_dropout_prob = ScheduledFloat((0.0, 0.5),
                                                      (warmup_batches, 0.025))

        def _to_tuple(x):
            """ Converts a single int or a 1-tuple of an int to a tuple with the same length
            as downsampling_factor"""
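The scheduled value above takes over the role of the hand-written _get_layer_skip_dropout_prob() removed further down; the training-mode check moves to the call site in forward(). A rough sanity check, illustrative only (warmup_batches = 4000.0 is an assumed value, and ScheduledFloat comes from the scaling.py hunk above), showing that the two formulations trace the same ramp:

from scaling import ScheduledFloat

warmup_batches = 4000.0  # assumed value, for illustration only
prob = ScheduledFloat((0.0, 0.5), (warmup_batches, 0.025))

for bc in [0, 2000, 4000, 10000]:
    prob.batch_count = bc
    # the formula used by the removed _get_layer_skip_dropout_prob()
    old_style = 0.025 if bc > warmup_batches else 0.5 - (bc / warmup_batches) * (0.5 - 0.025)
    assert abs(float(prob) - old_style) < 1e-9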
@@ -128,9 +137,6 @@ class Zipformer(EncoderInterface):
        feedforward_dim = _to_tuple(feedforward_dim)
        cnn_module_kernel = _to_tuple(cnn_module_kernel)

        # will be written to in training loop, see set_batch_count()
        self.batch_count = 0
        self.warmup_end = warmup_batches

        for u, d in zip(encoder_unmasked_dim, encoder_dim):
            assert u <= d
@@ -193,17 +199,6 @@ class Zipformer(EncoderInterface):
            downsample=output_downsampling_factor)


    def _get_layer_skip_dropout_prob(self):
        if not self.training:
            return 0.0
        batch_count = self.batch_count
        min_dropout_prob = 0.025

        if batch_count > self.warmup_end:
            return min_dropout_prob
        else:
            return 0.5 - (batch_count / self.warmup_end) * (0.5 - min_dropout_prob)

    def _init_skip_modules(self):
        """
        If self.downsampling_factor = (1, 2, 4, 8, 4, 2), then at the input of layer
@@ -321,8 +316,10 @@ class Zipformer(EncoderInterface):
        for i, module in enumerate(self.encoders):
            ds = self.downsampling_factor[i]
            if self.skip_layers[i] is not None:
                layer_skip_dropout_prob = self._get_layer_skip_dropout_prob()
                if (not self.training) or random.random() > layer_skip_dropout_prob:
                # This is how we implement U-net-like skipping of some series of
                # stacks.  The layer_skip_dropout_prob is to discourage it, especially
                # early in training, from completely ignoring the middle layers.
                if not (self.training and random.random() < float(self.layer_skip_dropout_prob)):
                    x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
            x = module(x,
                       feature_mask=feature_masks[i],
||||
@ -365,12 +362,24 @@ class ZipformerEncoderLayer(nn.Module):
|
||||
feedforward_dim: int,
|
||||
dropout: float = 0.1,
|
||||
cnn_module_kernel: int = 31,
|
||||
# layer_skip_prob will be overwritten to change warmup begin and end times.
|
||||
layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.5), (2000.0, 0.05)),
|
||||
dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.2), (2000.0, 0.0)),
|
||||
bypass_clamp_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25)),
|
||||
bypass_clamp_max: FloatLike = 1.0,
|
||||
) -> None:
|
||||
super(ZipformerEncoderLayer, self).__init__()
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
# will be written to in training loop, see set_batch_count()
|
||||
self.batch_count = 0
|
||||
# probability of skipping the entire layer.
|
||||
self.layer_skip_prob = copy.deepcopy(layer_skip_prob)
|
||||
# skip probability for dynamic modules (meaning: anything but feedforward)
|
||||
self.dynamic_skip_prob = copy.deepcopy(dynamic_skip_prob)
|
||||
# min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
|
||||
# ever becoming zero.
|
||||
self.bypass_clamp_min = copy.deepcopy(bypass_clamp_min)
|
||||
self.bypass_clamp_max = copy.deepcopy(bypass_clamp_max)
|
||||
|
||||
|
||||
self.self_attn_weights = RelPositionMultiheadAttentionWeights(
|
||||
embed_dim, pos_dim=pos_dim, num_heads=num_heads,
|
||||
@@ -424,36 +433,13 @@ class ZipformerEncoderLayer(nn.Module):
            grad_scale=0.01)

    def get_bypass_scale(self):
        if torch.jit.is_scripting() or not self.training:
        if torch.jit.is_scripting() or not self.training or random.random() < 0.5:
            # the random.random() part is to ensure we get grads if self.bypass_scale becomes out of range
            return self.bypass_scale
        if random.random() < 0.1:
            # ensure we get grads if self.bypass_scale becomes out of range
            return self.bypass_scale
        # hardcode warmup period for bypass scale
        warmup_period = 20000.0
        initial_clamp_min = 0.75
        final_clamp_min = 0.25
        if self.batch_count > warmup_period:
            clamp_min = final_clamp_min
        else:
            clamp_min = (initial_clamp_min -
                         (self.batch_count / warmup_period) * (initial_clamp_min - final_clamp_min))
        return self.bypass_scale.clamp(min=clamp_min, max=1.0)

    def get_dynamic_dropout_rate(self):
        # return dropout rate for the dynamic modules (self_attn, pooling, convolution); this
        # starts at 0.2 and rapidly decreases to 0.  Its purpose is to keep the training stable
        # at the beginning, by making the network focus on the feedforward modules.
        if torch.jit.is_scripting() or not self.training:
            return 0.0
        warmup_period = 2000.0
        initial_dropout_rate = 0.2
        final_dropout_rate = 0.0
        if self.batch_count > warmup_period:
            return final_dropout_rate
        else:
            return (initial_dropout_rate -
                    (initial_dropout_rate * final_dropout_rate) * (self.batch_count / warmup_period))
        return self.bypass_scale.clamp(min=float(self.bypass_clamp_min),
                                       max=float(self.bypass_clamp_max))


    def forward(
        self,
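The hard-coded bypass-scale warmup removed above (clamp_min ramping from 0.75 at batch 0 to 0.25 at batch 20000) matches the new bypass_clamp_min default of ScheduledFloat((0.0, 0.75), (20000.0, 0.25)). A small check, illustrative only; the helper below just restates the removed arithmetic:

from scaling import ScheduledFloat

clamp_min = ScheduledFloat((0.0, 0.75), (20000.0, 0.25))

def old_clamp_min(batch_count: float,
                  warmup_period: float = 20000.0,
                  initial_clamp_min: float = 0.75,
                  final_clamp_min: float = 0.25) -> float:
    # the arithmetic that get_bypass_scale() used to do inline
    if batch_count > warmup_period:
        return final_clamp_min
    return initial_clamp_min - (batch_count / warmup_period) * (initial_clamp_min - final_clamp_min)

for bc in [0, 5000, 20000, 50000]:
    clamp_min.batch_count = bc
    assert abs(float(clamp_min) - old_clamp_min(bc)) < 1e-9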
@@ -479,18 +465,20 @@ class ZipformerEncoderLayer(nn.Module):
            src_key_padding_mask: (N, S).
            S is the source sequence length, N is the batch size, E is the feature number
        """
        if self.training and random.random() < float(self.layer_skip_prob):
            # skip the layer
            return src

        src_orig = src

        # macaron style feed forward module
        src = src + self.feed_forward1(src)

        # dropout rate for submodules that interact with time.
        dynamic_dropout = self.get_dynamic_dropout_rate()

        # dropout rate for non-feedforward submodules
        dynamic_skip_prob = float(self.dynamic_skip_prob) if self.training else 0.0
        # multi-headed self-attention module
        # TODO: make the various attention-using models be dropped
        # out independently.
        use_self_attn = (random.random() > dynamic_dropout)
        use_self_attn = (random.random() >= dynamic_skip_prob)

        if torch.jit.is_scripting() or use_self_attn:
            # attn_weights: (num_heads, batch_size, seq_len, seq_len)
            attn_weights = self.self_attn_weights(

@@ -519,7 +507,7 @@ class ZipformerEncoderLayer(nn.Module):
            src = src + self.self_attn2(
                src, attn_weights)

        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
        if torch.jit.is_scripting() or random.random() >= dynamic_skip_prob:
            src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)

        src = src + self.feed_forward3(src)
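In the new forward() path, each "dynamic" branch (the attention-weight computation and the convolution module) is skipped independently with probability dynamic_skip_prob during training, and always runs at inference or under TorchScript. A generic sketch of that pattern, with hypothetical names (add_branch_with_skip, branch, skip_prob) rather than the actual submodules:

import random

import torch

def add_branch_with_skip(src: torch.Tensor, branch, skip_prob: float, training: bool) -> torch.Tensor:
    # Residual-add branch(src) unless the branch is randomly skipped during training;
    # scripting and eval mode always take the branch.
    if torch.jit.is_scripting() or not training or random.random() >= skip_prob:
        return src + branch(src)
    return src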
@@ -559,19 +547,15 @@ class ZipformerEncoder(nn.Module):
            pos_dim: int,
            dropout: float,
            warmup_begin: float,
            warmup_end: float
            warmup_end: float,
            initial_layerdrop_prob: float = 0.5,
            final_layerdrop_prob: float = 0.05,
    ) -> None:
        super().__init__()
        # will be written to, see set_batch_count().  Note: at inference time this
        # may be zero but should be treated as large; we can check whether
        # self.training is true.
        self.batch_count = 0
        self.warmup_begin = warmup_begin
        self.warmup_end = warmup_end
        # module_seed is for when we need a random number that is unique to the module but
        # shared across jobs.  It's used to randomly select how many layers to drop,
        # so that we can keep this consistent across worker tasks (for efficiency).
        self.module_seed = torch.randint(0, 1000, ()).item()

        self.encoder_pos = RelPositionalEncoding(pos_dim, dropout_rate=0.15)
@@ -582,75 +566,13 @@ class ZipformerEncoder(nn.Module):

        assert 0 <= warmup_begin <= warmup_end

        delta = (1. / num_layers) * (warmup_end - warmup_begin)
        cur_begin = warmup_begin
        cur_begin = warmup_begin  # interpreted as a training batch index
        for i in range(num_layers):
            self.layers[i].warmup_begin = cur_begin
            cur_begin += delta
            self.layers[i].warmup_end = cur_begin


    def get_layers_to_drop(self, rnd_seed: int):
        ans = set()
        if not self.training:
            return ans

        batch_count = self.batch_count
        num_layers = len(self.layers)

        def get_layerdrop_prob(layer: int) -> float:
            layer_warmup_begin = self.layers[layer].warmup_begin
            layer_warmup_end = self.layers[layer].warmup_end

            initial_layerdrop_prob = 0.5
            final_layerdrop_prob = 0.05

            if batch_count == 0:
                # As a special case, if batch_count == 0, return 0 (drop no
                # layers).  This is rather ugly, I'm afraid; it is intended to
                # enable our scan_pessimistic_batches_for_oom() code to work correctly
                # so if we are going to get OOM it will happen early.
                # also search for 'batch_count' with quotes in this file to see
                # how we initialize the warmup count to a random number between
                # 0 and 10.
                return 0.0
            elif batch_count < layer_warmup_begin:
                return initial_layerdrop_prob
            elif batch_count > layer_warmup_end:
                return final_layerdrop_prob
            else:
                # linearly interpolate
                t = (batch_count - layer_warmup_begin) / layer_warmup_end
                assert 0.0 <= t < 1.001
                return initial_layerdrop_prob + t * (final_layerdrop_prob - initial_layerdrop_prob)

        shared_rng = random.Random(batch_count + self.module_seed)
        independent_rng = random.Random(rnd_seed)

        layerdrop_probs = [ get_layerdrop_prob(i) for i in range(num_layers) ]
        tot = sum(layerdrop_probs)
        # Instead of drawing the samples independently, we first randomly decide
        # how many layers to drop out, using the same random number generator between
        # jobs so that all jobs drop out the same number (this is for speed).
        # Then we use an approximate approach to drop out the individual layers
        # with their specified probs while reaching this exact target.
        num_to_drop = int(tot) + int(shared_rng.random() < (tot - int(tot)))

        layers = list(range(num_layers))
        independent_rng.shuffle(layers)

        # go through the shuffled layers until we get the required number of samples.
        if num_to_drop > 0:
            for layer in itertools.cycle(layers):
                if independent_rng.random() < layerdrop_probs[layer]:
                    ans.add(layer)
                if len(ans) == num_to_drop:
                    break
        if shared_rng.random() < 0.005 or __name__ == "__main__":
            logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, "
                         f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}")
        return ans
            cur_end = cur_begin + delta
            self.layers[i].layer_skip_prob = ScheduledFloat((cur_begin, initial_layerdrop_prob),
                                                            (cur_end, final_layerdrop_prob))
            cur_begin = cur_end


    def forward(
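With this change, per-layer layer-drop scheduling lives in each layer's layer_skip_prob (a ScheduledFloat) instead of in get_layers_to_drop(): layer i ramps from initial_layerdrop_prob down to final_layerdrop_prob over its own window of delta = (warmup_end - warmup_begin) / num_layers batches. A small illustration of the windows the loop above builds; warmup_begin = 0, warmup_end = 4000 and num_layers = 4 are assumed values:

# Illustrative values only; the real ones come from the ZipformerEncoder arguments.
warmup_begin, warmup_end, num_layers = 0.0, 4000.0, 4
initial_layerdrop_prob, final_layerdrop_prob = 0.5, 0.05

delta = (1. / num_layers) * (warmup_end - warmup_begin)
cur_begin = warmup_begin
for i in range(num_layers):
    cur_end = cur_begin + delta
    print(f"layer {i}: skip prob ramps {initial_layerdrop_prob} -> {final_layerdrop_prob} "
          f"over batches [{cur_begin:.0f}, {cur_end:.0f}]")
    cur_begin = cur_end
# layer 0: batches [0, 1000]; layer 1: [1000, 2000]; layer 2: [2000, 3000]; layer 3: [3000, 4000]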
@@ -683,13 +605,10 @@ class ZipformerEncoder(nn.Module):


        rnd_seed = src.numel() + random.randint(0, 1000)
        layers_to_drop = self.get_layers_to_drop(rnd_seed)

        output = output * feature_mask

        for i, mod in enumerate(self.layers):
            if i in layers_to_drop:
                continue
            output = mod(
                output,
                pos_emb,
@@ -942,7 +861,7 @@ class SimpleCombiner(torch.nn.Module):


        weight1 = self.weight1
        if self.training and random.random() < 0.25 and self.min_weight != (0., 0.):
        if self.training and random.random() < 0.5 and self.min_weight != (0., 0.):
            weight1 = weight1.clamp(min=self.min_weight[0],
                                    max=1.0-self.min_weight[1])