Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Try to refactor the code for scheduling
This commit is contained in:
parent aa0b1a37cd
commit cd4730b657
@@ -981,6 +981,52 @@ class DoubleSwish(torch.nn.Module):
         return DoubleSwishFunction.apply(x)


+class ScheduledFloat(torch.nn.Module):
+    """
+    This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
+    it does not have a working forward() function.  You are supposed to cast it to float, as
+    in, float(parent_module.whatever), and use it as something like a dropout prob.
+
+    It is a floating point value whose value changes depending on the batch count of the
+    training loop.  It is a piecewise linear function where you specify the (x,y) pairs
+    in sorted order on x; x corresponds to the batch index.  For batch-index values before the
+    first x or after the last x, we just use the first or last y value.
+
+    Example:
+       self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0))
+    """
+    def __init__(self,
+                 *args):
+        super().__init__()
+        # self.batch_count will be written to in the training loop.
+        self.batch_count = 0
+        assert len(args) >= 1
+        for (x,y) in args:
+            assert x >= 0
+        for i in range(len(args) - 1):
+            assert args[i + 1] > args[i], args
+        self.schedule = args
+
+    def extra_repr(self) -> str:
+        return 'batch_count={}, schedule={}'.format(self.batch_count,
+                                                    self.schedule)
+
+    def __float__(self):
+        batch_count = self.batch_count
+        if batch_count <= self.schedule[0][0]:
+            return self.schedule[0][1]
+        elif batch_count >= self.schedule[-1][0]:
+            return self.schedule[-1][1]
+        else:
+            cur_x, cur_y = self.schedule[0]
+            for i in range(1, len(self.schedule)):
+                next_x, next_y = self.schedule[i]
+                if batch_count >= cur_x and batch_count <= next_x:
+                    return cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
+            assert False
+
+
+FloatLike = Union[float, ScheduledFloat]
+
+
 def _test_max_eig():
     for proportion in [0.1, 0.5, 10.0]:
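
As a reading aid (not part of the commit): a minimal standalone sketch of the piecewise-linear rule that ScheduledFloat.__float__ implements above. The piecewise_linear helper and the test values are illustrative only.

    # Hypothetical helper mirroring ScheduledFloat.__float__; not in the diff.
    from typing import Sequence, Tuple

    def piecewise_linear(batch_count: float,
                         schedule: Sequence[Tuple[float, float]]) -> float:
        """Evaluate (batch_index, value) pairs sorted on batch_index, clamping at the ends."""
        if batch_count <= schedule[0][0]:
            return schedule[0][1]
        if batch_count >= schedule[-1][0]:
            return schedule[-1][1]
        for (x0, y0), (x1, y1) in zip(schedule, schedule[1:]):
            if x0 <= batch_count <= x1:
                return y0 + (y1 - y0) * (batch_count - x0) / (x1 - x0)
        raise RuntimeError("unreachable for a sorted schedule")

    # The docstring's example schedule: dropout 0.2 at batch 0, decaying to 0.0 by batch 4000.
    assert piecewise_linear(0.0, [(0.0, 0.2), (4000.0, 0.0)]) == 0.2
    assert abs(piecewise_linear(2000.0, [(0.0, 0.2), (4000.0, 0.0)]) - 0.1) < 1e-12
    assert piecewise_linear(8000.0, [(0.0, 0.2), (4000.0, 0.0)]) == 0.0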
@@ -37,6 +37,8 @@ from scaling import (
     random_clamp,
     penalize_abs_values_gt,
     softmax,
+    ScheduledFloat,
+    FloatLike,
 )
 from torch import Tensor, nn

@@ -104,6 +106,13 @@ class Zipformer(EncoderInterface):
     ) -> None:
         super(Zipformer, self).__init__()

+        # this is not the probability of skipping a layer.  It is the probability of
+        # dropping out the "skip module" which allows the model to skip groups of
+        # encoder stacks; when it's dropped out like this, it means we are forced
+        # to take the "direct" (non-bypass) path.
+        self.layer_skip_dropout_prob = ScheduledFloat((0.0, 0.5),
+                                                      (warmup_batches, 0.025))
+
         def _to_tuple(x):
             """ Converts a single int or a 1-tuple of an int to a tuple with the same length
             as downsampling_factor"""
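
Sketch of how the training loop is expected to feed batch_count into these schedules. The comments in this diff refer to set_batch_count(); this exact body is an assumption, and the real helper in the training script may differ.

    import torch

    def set_batch_count(model: torch.nn.Module, batch_count: float) -> None:
        # ScheduledFloat subclasses nn.Module precisely so that model.modules()
        # reaches it here; Zipformer/ZipformerEncoder also expose batch_count.
        for module in model.modules():
            if hasattr(module, 'batch_count'):
                module.batch_count = batch_count

    # Afterwards float(model.layer_skip_dropout_prob) follows the
    # (0.0, 0.5) -> (warmup_batches, 0.025) schedule declared above.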
@@ -128,9 +137,6 @@ class Zipformer(EncoderInterface):
         feedforward_dim = _to_tuple(feedforward_dim)
         cnn_module_kernel = _to_tuple(cnn_module_kernel)

-        # will be written to in training loop, see set_batch_count()
-        self.batch_count = 0
-        self.warmup_end = warmup_batches

         for u,d in zip(encoder_unmasked_dim, encoder_dim):
             assert u <= d
@@ -193,17 +199,6 @@ class Zipformer(EncoderInterface):
                                               downsample=output_downsampling_factor)

-
-    def _get_layer_skip_dropout_prob(self):
-        if not self.training:
-            return 0.0
-        batch_count = self.batch_count
-        min_dropout_prob = 0.025
-
-        if batch_count > self.warmup_end:
-            return min_dropout_prob
-        else:
-            return 0.5 - (batch_count / self.warmup_end) * (0.5 - min_dropout_prob)
-
     def _init_skip_modules(self):
         """
         If self.downsampling_factor = (1, 2, 4, 8, 4, 2), then at the input of layer
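
The removed helper and the ScheduledFloat((0.0, 0.5), (warmup_batches, 0.025)) added in __init__ describe the same linear ramp; below is a quick numerical check of that equivalence (illustrative, not part of the commit; warmup_end = 4000 is an arbitrary test value).

    def old_layer_skip_dropout_prob(batch_count: float, warmup_end: float) -> float:
        # the hand-rolled schedule deleted above
        min_dropout_prob = 0.025
        if batch_count > warmup_end:
            return min_dropout_prob
        return 0.5 - (batch_count / warmup_end) * (0.5 - min_dropout_prob)

    def scheduled_float_value(batch_count: float, warmup_end: float) -> float:
        # piecewise-linear interpolation between (0, 0.5) and (warmup_end, 0.025),
        # i.e. what float(ScheduledFloat((0.0, 0.5), (warmup_end, 0.025))) returns
        if batch_count >= warmup_end:
            return 0.025
        return 0.5 + (0.025 - 0.5) * (batch_count / warmup_end)

    for bc in [0.0, 500.0, 2000.0, 4000.0, 10000.0]:
        assert abs(old_layer_skip_dropout_prob(bc, 4000.0) -
                   scheduled_float_value(bc, 4000.0)) < 1e-12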
@@ -321,8 +316,10 @@ class Zipformer(EncoderInterface):
         for i, module in enumerate(self.encoders):
             ds = self.downsampling_factor[i]
             if self.skip_layers[i] is not None:
-                layer_skip_dropout_prob = self._get_layer_skip_dropout_prob()
-                if (not self.training) or random.random() > layer_skip_dropout_prob:
+                # this is how we implement U-net-like skipping of some series of
+                # stacks.  The layer_skip_dropout_prob is to discourage it, especially
+                # early in training, from completely ignoring the middle layers.
+                if not (self.training and random.random() < float(self.layer_skip_dropout_prob)):
                     x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
             x = module(x,
                        feature_mask=feature_masks[i],
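
On the rewritten condition: the skip module is now used with probability 1 - p during training and always in eval mode, where p = float(self.layer_skip_dropout_prob). A tiny sanity sketch, illustrative only:

    import random

    def takes_skip_path(training: bool, p: float) -> bool:
        # mirrors: if not (self.training and random.random() < p): use the skip module
        return not (training and random.random() < p)

    random.seed(0)
    rate = sum(takes_skip_path(True, 0.5) for _ in range(100_000)) / 100_000
    assert abs(rate - 0.5) < 0.01       # ~50% usage at the start of the schedule
    assert takes_skip_path(False, 0.5)  # always used once self.training is False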
@@ -365,12 +362,24 @@ class ZipformerEncoderLayer(nn.Module):
             feedforward_dim: int,
             dropout: float = 0.1,
             cnn_module_kernel: int = 31,
+            # layer_skip_prob will be overwritten to change warmup begin and end times.
+            layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.5), (2000.0, 0.05)),
+            dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.2), (2000.0, 0.0)),
+            bypass_clamp_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25)),
+            bypass_clamp_max: FloatLike = 1.0,
     ) -> None:
         super(ZipformerEncoderLayer, self).__init__()
         self.embed_dim = embed_dim

-        # will be written to in training loop, see set_batch_count()
-        self.batch_count = 0
+        # probability of skipping the entire layer.
+        self.layer_skip_prob = copy.deepcopy(layer_skip_prob)
+        # skip probability for dynamic modules (meaning: anything but feedforward)
+        self.dynamic_skip_prob = copy.deepcopy(dynamic_skip_prob)
+        # min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
+        # ever becoming zero.
+        self.bypass_clamp_min = copy.deepcopy(bypass_clamp_min)
+        self.bypass_clamp_max = copy.deepcopy(bypass_clamp_max)

         self.self_attn_weights = RelPositionMultiheadAttentionWeights(
             embed_dim, pos_dim=pos_dim, num_heads=num_heads,
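
A note on the copy.deepcopy calls: the ScheduledFloat defaults are created once, when the function is defined, so without the copy every layer would hold the same schedule object. A small sketch of that Python behaviour (assumes scaling.py is importable; not part of the commit):

    import copy
    from scaling import ScheduledFloat

    default = ScheduledFloat((0.0, 0.2), (2000.0, 0.0))   # same values as the dynamic_skip_prob default

    shared_a, shared_b = default, default                           # what happens without the copy
    copied_a, copied_b = copy.deepcopy(default), copy.deepcopy(default)

    assert shared_a is shared_b        # one object shared by every layer
    assert copied_a is not copied_b    # each layer owns its own schedule module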
@@ -424,36 +433,13 @@ class ZipformerEncoderLayer(nn.Module):
                             grad_scale=0.01)

     def get_bypass_scale(self):
-        if torch.jit.is_scripting() or not self.training:
+        if torch.jit.is_scripting() or not self.training or random.random() < 0.5:
+            # the random.random() part is to ensure we get grads if self.bypass_scale becomes out of range
             return self.bypass_scale
-        if random.random() < 0.1:
-            # ensure we get grads if self.bypass_scale becomes out of range
-            return self.bypass_scale
-        # hardcode warmup period for bypass scale
-        warmup_period = 20000.0
-        initial_clamp_min = 0.75
-        final_clamp_min = 0.25
-        if self.batch_count > warmup_period:
-            clamp_min = final_clamp_min
-        else:
-            clamp_min = (initial_clamp_min -
-                         (self.batch_count / warmup_period) * (initial_clamp_min - final_clamp_min))
-        return self.bypass_scale.clamp(min=clamp_min, max=1.0)
-
-    def get_dynamic_dropout_rate(self):
-        # return dropout rate for the dynamic modules (self_attn, pooling, convolution); this
-        # starts at 0.2 and rapidly decreases to 0.  Its purpose is to keep the training stable
-        # at the beginning, by making the network focus on the feedforward modules.
-        if torch.jit.is_scripting() or not self.training:
-            return 0.0
-        warmup_period = 2000.0
-        initial_dropout_rate = 0.2
-        final_dropout_rate = 0.0
-        if self.batch_count > warmup_period:
-            return final_dropout_rate
-        else:
-            return (initial_dropout_rate -
-                    (initial_dropout_rate * final_dropout_rate) * (self.batch_count / warmup_period))
+        return self.bypass_scale.clamp(min=float(self.bypass_clamp_min),
+                                       max=float(self.bypass_clamp_max))

     def forward(
         self,
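
Why the clamp is applied only with probability 0.5: once bypass_scale drifts outside [bypass_clamp_min, bypass_clamp_max], the clamped branch passes no gradient back to it, so the unclamped branch keeps it trainable. A small demonstration with illustrative values (not part of the commit):

    import torch

    bypass_scale = torch.tensor([0.1], requires_grad=True)   # below a clamp_min of 0.25

    bypass_scale.clamp(min=0.25, max=1.0).sum().backward()
    print(bypass_scale.grad)    # tensor([0.]) -- the clamp blocks the gradient

    bypass_scale.grad = None
    bypass_scale.sum().backward()
    print(bypass_scale.grad)    # tensor([1.]) -- the unclamped branch still learns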
@@ -479,18 +465,20 @@ class ZipformerEncoderLayer(nn.Module):
         src_key_padding_mask: (N, S).
         S is the source sequence length, N is the batch size, E is the feature number
         """
+        if self.training and random.random() < float(self.layer_skip_prob):
+            # skip the layer
+            return src
+
         src_orig = src

         # macaron style feed forward module
         src = src + self.feed_forward1(src)

-        # dropout rate for submodules that interact with time.
-        dynamic_dropout = self.get_dynamic_dropout_rate()
+        # dropout rate for non-feedforward submodules
+        dynamic_skip_prob = float(self.dynamic_skip_prob) if self.training else 0.0

         # multi-headed self-attention module
-        # TODO: make the various attention-using models be dropped
-        # out independently.
-        use_self_attn = (random.random() > dynamic_dropout)
+        use_self_attn = (random.random() >= dynamic_skip_prob)
         if torch.jit.is_scripting() or use_self_attn:
             # attn_weights: (num_heads, batch_size, seq_len, seq_len)
             attn_weights = self.self_attn_weights(
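
The early return above is the new layer-skip mechanism: during training the whole layer becomes an identity with probability float(self.layer_skip_prob). For intuition, the expected number of layers actually evaluated per batch (the 12-layer count is illustrative; 0.5 and 0.05 are the schedule endpoints from the defaults above):

    def expected_evaluated_layers(num_layers: int, layer_skip_prob: float) -> float:
        # each layer runs independently with probability 1 - layer_skip_prob
        return num_layers * (1.0 - layer_skip_prob)

    print(expected_evaluated_layers(12, 0.5))    # 6.0 at batch 0
    print(expected_evaluated_layers(12, 0.05))   # 11.4 once the schedule has decayed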
@@ -519,7 +507,7 @@ class ZipformerEncoderLayer(nn.Module):
             src = src + self.self_attn2(
                 src, attn_weights)

-        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+        if torch.jit.is_scripting() or random.random() >= dynamic_skip_prob:
             src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)

         src = src + self.feed_forward3(src)
@@ -559,19 +547,15 @@ class ZipformerEncoder(nn.Module):
             pos_dim: int,
             dropout: float,
             warmup_begin: float,
-            warmup_end: float
+            warmup_end: float,
+            initial_layerdrop_prob: float = 0.5,
+            final_layerdrop_prob: float = 0.05,
     ) -> None:
         super().__init__()
         # will be written to, see set_batch_count()  Note: in inference time this
         # may be zero but should be treated as large, we can check if
         # self.training is true.
         self.batch_count = 0
-        self.warmup_begin = warmup_begin
-        self.warmup_end = warmup_end
-        # module_seed is for when we need a random number that is unique to the module but
-        # shared across jobs.  It's used to randomly select how many layers to drop,
-        # so that we can keep this consistent across worker tasks (for efficiency).
-        self.module_seed = torch.randint(0, 1000, ()).item()

         self.encoder_pos = RelPositionalEncoding(pos_dim, dropout_rate=0.15)

@@ -582,75 +566,13 @@ class ZipformerEncoder(nn.Module):

         assert 0 <= warmup_begin <= warmup_end

         delta = (1. / num_layers) * (warmup_end - warmup_begin)
-        cur_begin = warmup_begin
+        cur_begin = warmup_begin  # interpreted as a training batch index
         for i in range(num_layers):
-            self.layers[i].warmup_begin = cur_begin
-            cur_begin += delta
-            self.layers[i].warmup_end = cur_begin
+            cur_end = cur_begin + delta
+            self.layers[i].layer_skip_prob = ScheduledFloat((cur_begin, initial_layerdrop_prob),
+                                                            (cur_end, final_layerdrop_prob))
+            cur_begin = cur_end

-    def get_layers_to_drop(self, rnd_seed: int):
-        ans = set()
-        if not self.training:
-            return ans
-
-        batch_count = self.batch_count
-        num_layers = len(self.layers)
-
-        def get_layerdrop_prob(layer: int) -> float:
-            layer_warmup_begin = self.layers[layer].warmup_begin
-            layer_warmup_end = self.layers[layer].warmup_end
-
-            initial_layerdrop_prob = 0.5
-            final_layerdrop_prob = 0.05
-
-            if batch_count == 0:
-                # As a special case, if batch_count == 0, return 0 (drop no
-                # layers).  This is rather ugly, I'm afraid; it is intended to
-                # enable our scan_pessimistic_batches_for_oom() code to work correctly
-                # so if we are going to get OOM it will happen early.
-                # also search for 'batch_count' with quotes in this file to see
-                # how we initialize the warmup count to a random number between
-                # 0 and 10.
-                return 0.0
-            elif batch_count < layer_warmup_begin:
-                return initial_layerdrop_prob
-            elif batch_count > layer_warmup_end:
-                return final_layerdrop_prob
-            else:
-                # linearly interpolate
-                t = (batch_count - layer_warmup_begin) / layer_warmup_end
-                assert 0.0 <= t < 1.001
-                return initial_layerdrop_prob + t * (final_layerdrop_prob - initial_layerdrop_prob)
-
-        shared_rng = random.Random(batch_count + self.module_seed)
-        independent_rng = random.Random(rnd_seed)
-
-        layerdrop_probs = [ get_layerdrop_prob(i) for i in range(num_layers) ]
-        tot = sum(layerdrop_probs)
-        # Instead of drawing the samples independently, we first randomly decide
-        # how many layers to drop out, using the same random number generator between
-        # jobs so that all jobs drop out the same number (this is for speed).
-        # Then we use an approximate approach to drop out the individual layers
-        # with their specified probs while reaching this exact target.
-        num_to_drop = int(tot) + int(shared_rng.random() < (tot - int(tot)))
-
-        layers = list(range(num_layers))
-        independent_rng.shuffle(layers)
-
-        # go through the shuffled layers until we get the required number of samples.
-        if num_to_drop > 0:
-            for layer in itertools.cycle(layers):
-                if independent_rng.random() < layerdrop_probs[layer]:
-                    ans.add(layer)
-                if len(ans) == num_to_drop:
-                    break
-        if shared_rng.random() < 0.005 or __name__ == "__main__":
-            logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, "
-                         f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}")
-        return ans
-
     def forward(
         self,
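
The rewritten loop above gives each layer its own warmup window, staggered across [warmup_begin, warmup_end]. A sketch of the resulting schedules for a hypothetical 4-layer stack with warmup_begin=0 and warmup_end=2000 (illustrative numbers; the 0.5 and 0.05 endpoints are the constructor defaults):

    num_layers, warmup_begin, warmup_end = 4, 0.0, 2000.0
    initial_layerdrop_prob, final_layerdrop_prob = 0.5, 0.05

    delta = (1.0 / num_layers) * (warmup_end - warmup_begin)
    cur_begin = warmup_begin
    for i in range(num_layers):
        cur_end = cur_begin + delta
        print(f"layer {i}: layer_skip_prob ramps {initial_layerdrop_prob} -> "
              f"{final_layerdrop_prob} over batches {cur_begin:.0f}..{cur_end:.0f}")
        cur_begin = cur_end
    # layer 0: 0.5 -> 0.05 over batches 0..500
    # layer 1: 0.5 -> 0.05 over batches 500..1000, and so on.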
@@ -683,13 +605,10 @@ class ZipformerEncoder(nn.Module):


         rnd_seed = src.numel() + random.randint(0, 1000)
-        layers_to_drop = self.get_layers_to_drop(rnd_seed)

         output = output * feature_mask

         for i, mod in enumerate(self.layers):
-            if i in layers_to_drop:
-                continue
             output = mod(
                 output,
                 pos_emb,
@@ -942,7 +861,7 @@ class SimpleCombiner(torch.nn.Module):


         weight1 = self.weight1
-        if self.training and random.random() < 0.25 and self.min_weight != (0., 0.):
+        if self.training and random.random() < 0.5 and self.min_weight != (0., 0.):
             weight1 = weight1.clamp(min=self.min_weight[0],
                                     max=1.0-self.min_weight[1])
