diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index b9322f013..e04e5aa0c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -981,6 +981,52 @@ class DoubleSwish(torch.nn.Module):
         return DoubleSwishFunction.apply(x)
 
 
+class ScheduledFloat(torch.nn.Module):
+    """
+    This object is a torch.nn.Module only because we want it to show up in
+    [top_level module].modules(); it does not have a working forward() function.
+    You are supposed to cast it to float, as in float(parent_module.whatever),
+    and use it as something like a dropout prob.
+
+    It is a floating point value whose value changes depending on the batch count
+    of the training loop.  It is a piecewise linear function where you specify the
+    (x, y) pairs in sorted order on x; x corresponds to the batch index.  For
+    batch-index values before the first x or after the last x, we just use the
+    first or last y value.
+
+    Example:
+        self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0))
+    """
+    def __init__(self,
+                 *args):
+        super().__init__()
+        # self.batch_count will be written to in the training loop.
+        self.batch_count = 0
+        assert len(args) >= 1
+        for (x, y) in args:
+            assert x >= 0
+        for i in range(len(args) - 1):
+            assert args[i + 1] > args[i], args
+        self.schedule = args
+
+    def extra_repr(self) -> str:
+        return 'batch_count={}, schedule={}'.format(self.batch_count,
+                                                    self.schedule)
+
+    def __float__(self):
+        batch_count = self.batch_count
+        if batch_count <= self.schedule[0][0]:
+            return self.schedule[0][1]
+        elif batch_count >= self.schedule[-1][0]:
+            return self.schedule[-1][1]
+        else:
+            cur_x, cur_y = self.schedule[0]
+            for i in range(1, len(self.schedule)):
+                next_x, next_y = self.schedule[i]
+                if batch_count >= cur_x and batch_count <= next_x:
+                    return cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
+            assert False
+
+
+FloatLike = Union[float, ScheduledFloat]
+
 
 def _test_max_eig():
     for proportion in [0.1, 0.5, 10.0]:
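
Note (not part of the patch): ScheduledFloat evaluates to a piecewise-linear function of its
batch_count attribute, holding the first/last y value outside the given x range.  A minimal
sketch of the intended behaviour, assuming scaling.py is importable:

    from scaling import ScheduledFloat

    # 0.2 at batch 0, decaying linearly to 0.0 by batch 4000 (the docstring's example).
    dropout_schedule = ScheduledFloat((0.0, 0.2), (4000.0, 0.0))

    dropout_schedule.batch_count = 0
    assert float(dropout_schedule) == 0.2                # before the first x: first y
    dropout_schedule.batch_count = 2000
    assert abs(float(dropout_schedule) - 0.1) < 1e-6     # halfway: linear interpolation
    dropout_schedule.batch_count = 10000
    assert float(dropout_schedule) == 0.0                # after the last x: last y

Nothing updates batch_count automatically; the training loop is expected to write it
(see the note at the end of this diff).
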
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 31997f8c4..c4cd5dc9b 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -37,6 +37,8 @@ from scaling import (
     random_clamp,
     penalize_abs_values_gt,
     softmax,
+    ScheduledFloat,
+    FloatLike,
 )
 from torch import Tensor, nn
 
@@ -104,6 +106,13 @@ class Zipformer(EncoderInterface):
     ) -> None:
         super(Zipformer, self).__init__()
 
+        # This is not the probability of skipping a layer.  It is the probability of
+        # dropping out the "skip module" that lets the model skip groups of encoder
+        # stacks; when the skip module is dropped out, the model is forced to take
+        # the "direct" (non-bypass) path.
+        self.layer_skip_dropout_prob = ScheduledFloat((0.0, 0.5),
+                                                      (warmup_batches, 0.025))
+
     def _to_tuple(x):
         """ Converts a single int or a 1-tuple of an int to a tuple with the same length as
         downsampling_factor"""
@@ -128,9 +137,6 @@ class Zipformer(EncoderInterface):
         feedforward_dim = _to_tuple(feedforward_dim)
         cnn_module_kernel = _to_tuple(cnn_module_kernel)
 
-        # will be written to in training loop, see set_batch_count()
-        self.batch_count = 0
-        self.warmup_end = warmup_batches
         for u,d in zip(encoder_unmasked_dim, encoder_dim):
             assert u <= d
 
@@ -193,17 +199,6 @@ class Zipformer(EncoderInterface):
             downsample=output_downsampling_factor)
 
-    def _get_layer_skip_dropout_prob(self):
-        if not self.training:
-            return 0.0
-        batch_count = self.batch_count
-        min_dropout_prob = 0.025
-
-        if batch_count > self.warmup_end:
-            return min_dropout_prob
-        else:
-            return 0.5 - (batch_count / self.warmup_end) * (0.5 - min_dropout_prob)
-
     def _init_skip_modules(self):
         """
         If self.downampling_factor = (1, 2, 4, 8, 4, 2), then at the input of layer
@@ -321,8 +316,10 @@ class Zipformer(EncoderInterface):
         for i, module in enumerate(self.encoders):
             ds = self.downsampling_factor[i]
             if self.skip_layers[i] is not None:
-                layer_skip_dropout_prob = self._get_layer_skip_dropout_prob()
-                if (not self.training) or random.random() > layer_skip_dropout_prob:
+                # This is how we implement U-net-like skipping of some series of stacks.
+                # The layer_skip_dropout_prob is there to discourage the model, especially
+                # early in training, from completely ignoring the middle layers.
+                if not (self.training and random.random() < float(self.layer_skip_dropout_prob)):
                     x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
             x = module(x,
                        feature_mask=feature_masks[i],
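
Note (not part of the patch): the gate above drops the skip module with probability
float(self.layer_skip_dropout_prob) during training, forcing the direct path; at inference
time the skip module is always applied.  With the (0.0, 0.5) -> (warmup_batches, 0.025)
schedule this happens roughly half of the time early in training, decaying to 2.5%.  A small
sketch of the same pattern with a free-standing schedule (warmup_batches assumed to be 4000;
the helper name below is illustrative, not from the patch):

    import random
    from scaling import ScheduledFloat

    layer_skip_dropout_prob = ScheduledFloat((0.0, 0.5), (4000.0, 0.025))

    def maybe_apply_skip_module(skip_module, skip_input, x, training: bool):
        # Drop the skip module with the scheduled probability during training,
        # i.e. keep the "direct" (non-bypass) path; always apply it in eval mode.
        if not (training and random.random() < float(layer_skip_dropout_prob)):
            return skip_module(skip_input, x)
        return x
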
@@ -365,12 +362,24 @@ class ZipformerEncoderLayer(nn.Module):
         feedforward_dim: int,
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
+        # layer_skip_prob will be overwritten to change warmup begin and end times.
+        layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.5), (2000.0, 0.05)),
+        dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.2), (2000.0, 0.0)),
+        bypass_clamp_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25)),
+        bypass_clamp_max: FloatLike = 1.0,
     ) -> None:
         super(ZipformerEncoderLayer, self).__init__()
         self.embed_dim = embed_dim
 
-        # will be written to in training loop, see set_batch_count()
-        self.batch_count = 0
+        # probability of skipping the entire layer.
+        self.layer_skip_prob = copy.deepcopy(layer_skip_prob)
+        # skip probability for the dynamic modules (meaning: anything but feedforward).
+        self.dynamic_skip_prob = copy.deepcopy(dynamic_skip_prob)
+        # min and max for self.bypass_scale; the clamp is only applied with probability
+        # 0.5 so that its gradients never become exactly zero.
+        self.bypass_clamp_min = copy.deepcopy(bypass_clamp_min)
+        self.bypass_clamp_max = copy.deepcopy(bypass_clamp_max)
+
         self.self_attn_weights = RelPositionMultiheadAttentionWeights(
             embed_dim, pos_dim=pos_dim, num_heads=num_heads,
@@ -424,36 +433,13 @@ class ZipformerEncoderLayer(nn.Module):
             grad_scale=0.01)
 
     def get_bypass_scale(self):
-        if torch.jit.is_scripting() or not self.training:
+        if torch.jit.is_scripting() or not self.training or random.random() < 0.5:
+            # the random.random() part is to ensure we get grads if self.bypass_scale becomes out of range
             return self.bypass_scale
-        if random.random() < 0.1:
-            # ensure we get grads if self.bypass_scale becomes out of range
-            return self.bypass_scale
-        # hardcode warmup period for bypass scale
-        warmup_period = 20000.0
-        initial_clamp_min = 0.75
-        final_clamp_min = 0.25
-        if self.batch_count > warmup_period:
-            clamp_min = final_clamp_min
-        else:
-            clamp_min = (initial_clamp_min -
-                         (self.batch_count / warmup_period) * (initial_clamp_min - final_clamp_min))
-        return self.bypass_scale.clamp(min=clamp_min, max=1.0)
-
-    def get_dynamic_dropout_rate(self):
-        # return dropout rate for the dynamic modules (self_attn, pooling, convolution); this
-        # starts at 0.2 and rapidly decreases to 0.  Its purpose is to keep the training stable
-        # at the beginning, by making the network focus on the feedforward modules.
-        if torch.jit.is_scripting() or not self.training:
-            return 0.0
-        warmup_period = 2000.0
-        initial_dropout_rate = 0.2
-        final_dropout_rate = 0.0
-        if self.batch_count > warmup_period:
-            return final_dropout_rate
-        else:
-            return (initial_dropout_rate -
-                    (initial_dropout_rate * final_dropout_rate) * (self.batch_count / warmup_period))
+        return self.bypass_scale.clamp(min=float(self.bypass_clamp_min),
+                                       max=float(self.bypass_clamp_max))
+
 
     def forward(
         self,
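
Note (not part of the patch): the reason get_bypass_scale() returns the unclamped parameter
half of the time is that clamp() has zero gradient wherever its input lies outside
[min, max], so a bypass_scale that drifted out of range would otherwise stop receiving
gradients entirely.  A minimal illustration with a plain tensor:

    import torch

    bypass_scale = torch.tensor([1.5], requires_grad=True)   # already above the max of 1.0

    bypass_scale.clamp(min=0.25, max=1.0).sum().backward()
    print(bypass_scale.grad)   # tensor([0.]) -- clamping blocks the gradient out of range

    bypass_scale.grad = None
    bypass_scale.sum().backward()
    print(bypass_scale.grad)   # tensor([1.]) -- the unclamped branch still gives a gradient
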
@@ -479,18 +465,20 @@ class ZipformerEncoderLayer(nn.Module):
         src_key_padding_mask: (N, S).
         S is the source sequence length, N is the batch size, E is the feature number
         """
+        if self.training and random.random() < float(self.layer_skip_prob):
+            # skip the layer
+            return src
+
         src_orig = src
 
         # macaron style feed forward module
         src = src + self.feed_forward1(src)
 
-        # dropout rate for submodules that interact with time.
-        dynamic_dropout = self.get_dynamic_dropout_rate()
-
+        # skip probability for the non-feedforward submodules
+        dynamic_skip_prob = float(self.dynamic_skip_prob) if self.training else 0.0
         # multi-headed self-attention module
-        # TODO: make the various attention-using models be dropped
-        # out independently.
-        use_self_attn = (random.random() > dynamic_dropout)
+        use_self_attn = (random.random() >= dynamic_skip_prob)
+
         if torch.jit.is_scripting() or use_self_attn:
             # attn_weights: (num_heads, batch_size, seq_len, seq_len)
             attn_weights = self.self_attn_weights(
@@ -519,7 +507,7 @@ class ZipformerEncoderLayer(nn.Module):
             src = src + self.self_attn2(
                 src, attn_weights)
 
-        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+        if torch.jit.is_scripting() or random.random() >= dynamic_skip_prob:
             src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
 
         src = src + self.feed_forward3(src)
@@ -559,19 +547,15 @@ class ZipformerEncoder(nn.Module):
         pos_dim: int,
         dropout: float,
         warmup_begin: float,
-        warmup_end: float
+        warmup_end: float,
+        initial_layerdrop_prob: float = 0.5,
+        final_layerdrop_prob: float = 0.05,
     ) -> None:
         super().__init__()
         # will be written to, see set_batch_count() Note: in inference time this
         # may be zero but should be treated as large, we can check if
         # self.training is true.
         self.batch_count = 0
-        self.warmup_begin = warmup_begin
-        self.warmup_end = warmup_end
-        # module_seed is for when we need a random number that is unique to the module but
-        # shared across jobs.  It's used to randomly select how many layers to drop,
-        # so that we can keep this consistent across worker tasks (for efficiency).
-        self.module_seed = torch.randint(0, 1000, ()).item()
 
         self.encoder_pos = RelPositionalEncoding(pos_dim, dropout_rate=0.15)
 
@@ -582,75 +566,13 @@ class ZipformerEncoder(nn.Module):
 
         assert 0 <= warmup_begin <= warmup_end
 
         delta = (1. / num_layers) * (warmup_end - warmup_begin)
-        cur_begin = warmup_begin
+        cur_begin = warmup_begin  # interpreted as a training batch index
         for i in range(num_layers):
-            self.layers[i].warmup_begin = cur_begin
-            cur_begin += delta
-            self.layers[i].warmup_end = cur_begin
-
-
-    def get_layers_to_drop(self, rnd_seed: int):
-        ans = set()
-        if not self.training:
-            return ans
-
-        batch_count = self.batch_count
-        num_layers = len(self.layers)
-
-        def get_layerdrop_prob(layer: int) -> float:
-            layer_warmup_begin = self.layers[layer].warmup_begin
-            layer_warmup_end = self.layers[layer].warmup_end
-
-            initial_layerdrop_prob = 0.5
-            final_layerdrop_prob = 0.05
-
-            if batch_count == 0:
-                # As a special case, if batch_count == 0, return 0 (drop no
-                # layers).  This is rather ugly, I'm afraid; it is intended to
-                # enable our scan_pessimistic_batches_for_oom() code to work correctly
-                # so if we are going to get OOM it will happen early.
-                # also search for 'batch_count' with quotes in this file to see
-                # how we initialize the warmup count to a random number between
-                # 0 and 10.
-                return 0.0
-            elif batch_count < layer_warmup_begin:
-                return initial_layerdrop_prob
-            elif batch_count > layer_warmup_end:
-                return final_layerdrop_prob
-            else:
-                # linearly interpolate
-                t = (batch_count - layer_warmup_begin) / layer_warmup_end
-                assert 0.0 <= t < 1.001
-                return initial_layerdrop_prob + t * (final_layerdrop_prob - initial_layerdrop_prob)
-
-        shared_rng = random.Random(batch_count + self.module_seed)
-        independent_rng = random.Random(rnd_seed)
-
-        layerdrop_probs = [ get_layerdrop_prob(i) for i in range(num_layers) ]
-        tot = sum(layerdrop_probs)
-        # Instead of drawing the samples independently, we first randomly decide
-        # how many layers to drop out, using the same random number generator between
-        # jobs so that all jobs drop out the same number (this is for speed).
-        # Then we use an approximate approach to drop out the individual layers
-        # with their specified probs while reaching this exact target.
-        num_to_drop = int(tot) + int(shared_rng.random() < (tot - int(tot)))
-
-        layers = list(range(num_layers))
-        independent_rng.shuffle(layers)
-
-        # go through the shuffled layers until we get the required number of samples.
-        if num_to_drop > 0:
-            for layer in itertools.cycle(layers):
-                if independent_rng.random() < layerdrop_probs[layer]:
-                    ans.add(layer)
-                if len(ans) == num_to_drop:
-                    break
-        if shared_rng.random() < 0.005 or __name__ == "__main__":
-            logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, "
-                         f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}")
-        return ans
+            cur_end = cur_begin + delta
+            self.layers[i].layer_skip_prob = ScheduledFloat((cur_begin, initial_layerdrop_prob),
+                                                            (cur_end, final_layerdrop_prob))
+            cur_begin = cur_end
 
 
     def forward(
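
Note (not part of the patch): the loop above replaces the shared-RNG get_layers_to_drop()
scheme with an independent, staggered layerdrop schedule per layer: every layer starts at
initial_layerdrop_prob and decays to final_layerdrop_prob over its own slice of
[warmup_begin, warmup_end].  A standalone sketch with assumed values (6 layers, warmup from
batch 0 to 3000):

    from scaling import ScheduledFloat

    num_layers = 6
    warmup_begin, warmup_end = 0.0, 3000.0
    initial_layerdrop_prob, final_layerdrop_prob = 0.5, 0.05

    delta = (1. / num_layers) * (warmup_end - warmup_begin)   # 500 batches per layer
    cur_begin = warmup_begin
    schedules = []
    for i in range(num_layers):
        cur_end = cur_begin + delta
        schedules.append(ScheduledFloat((cur_begin, initial_layerdrop_prob),
                                        (cur_end, final_layerdrop_prob)))
        cur_begin = cur_end

    for i, s in enumerate(schedules):
        s.batch_count = 1000
        print(i, round(float(s), 3))
    # at batch 1000, layers 0 and 1 have already decayed to 0.05 while the
    # later layers are still at 0.5
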
@@ -683,13 +605,10 @@ class ZipformerEncoder(nn.Module):
 
         rnd_seed = src.numel() + random.randint(0, 1000)
-        layers_to_drop = self.get_layers_to_drop(rnd_seed)
 
         output = output * feature_mask
 
         for i, mod in enumerate(self.layers):
-            if i in layers_to_drop:
-                continue
             output = mod(
                 output,
                 pos_emb,
@@ -942,7 +861,7 @@ class SimpleCombiner(torch.nn.Module):
 
         weight1 = self.weight1
 
-        if self.training and random.random() < 0.25 and self.min_weight != (0., 0.):
+        if self.training and random.random() < 0.5 and self.min_weight != (0., 0.):
             weight1 = weight1.clamp(min=self.min_weight[0],
                                     max=1.0-self.min_weight[1])
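
Note (not part of the patch): none of the ScheduledFloat values move unless the training
loop writes batch_count into them; that is why ScheduledFloat subclasses torch.nn.Module,
so that the schedules show up in model.modules().  The set_batch_count() helper referred to
in the comments is expected to do something along these lines (this exact implementation is
an assumption, not taken from the patch):

    import torch

    def set_batch_count(model: torch.nn.Module, batch_count: float) -> None:
        # Every submodule that exposes a batch_count attribute (ScheduledFloat,
        # ZipformerEncoder, ...) gets the current batch index written into it.
        for module in model.modules():
            if hasattr(module, 'batch_count'):
                module.batch_count = batch_count
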