diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 21679589f..1bf4593d4 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -89,6 +89,17 @@ LRSchedulerType = Union[
 ]


+def set_batch_count(
+    model: Union[nn.Module, DDP], batch_count: float
+) -> None:
+    if isinstance(model, DDP):
+        # get underlying nn.Module
+        model = model.module
+    for module in model.modules():
+        if hasattr(module, 'batch_count'):
+            module.batch_count = batch_count
+
+
 def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--num-encoder-layers",
@@ -809,6 +820,7 @@ def train_one_epoch(
        # NOTE: We use reduction==sum and loss is computed over utterances
        # in the batch and there is no normalization to it so far.
        scaler.scale(loss).backward()
+       set_batch_count(model, params.batch_idx_train)
        scheduler.step_batch(params.batch_idx_train)

        scaler.step(optimizer)
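
For orientation: set_batch_count() simply walks the (possibly DDP-wrapped) model and stamps the current training batch index onto every submodule that declares a `batch_count` attribute; the encoders read it for warmup scheduling. A minimal, self-contained sketch of that mechanism (the `DummyLayer` module is hypothetical, not part of the recipe):

```python
from typing import Union

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
    # Same logic as the helper added to train.py: unwrap DDP, then assign
    # batch_count to every submodule that declares such an attribute.
    if isinstance(model, DDP):
        model = model.module
    for module in model.modules():
        if hasattr(module, "batch_count"):
            module.batch_count = batch_count


class DummyLayer(nn.Module):
    """Stand-in for ZipformerEncoder: declares batch_count so it gets updated."""

    def __init__(self) -> None:
        super().__init__()
        self.batch_count = 0.0  # written to by the top-level training loop
        self.linear = nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = nn.Sequential(DummyLayer(), DummyLayer())
for batch_idx_train in range(1, 4):
    set_batch_count(model, float(batch_idx_train))
print([m.batch_count for m in model.modules() if hasattr(m, "batch_count")])
# -> [3.0, 3.0]
```

Because the count is pushed in from train.py on every batch rather than kept in a registered buffer, it also sidesteps the model-averaging concern mentioned in the removed warmup_count comment below.
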
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 19c52be38..c7cc4440a 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -41,6 +41,7 @@ from scaling import (
 from torch import Tensor, nn

 from icefall.utils import make_pad_mask
+from icefall.dist import get_rank


 class Zipformer(EncoderInterface):
@@ -84,15 +85,6 @@ class Zipformer(EncoderInterface):
         self.zipformer_downsampling_factors = zipformer_downsampling_factors
         self.output_downsampling_factor = output_downsampling_factor

-        # keep track of how many times forward() has been called, for purposes
-        # of warmup.  do this with a floating-point count because integer counts
-        # can fail to survive model averaging.  initialize with a smallish
-        # random number so that different encoders use different random seeds in
-        # shared_rng get_layers_to_drop() while using the same random seeds
-        # across jobs.
-        self.register_buffer('warmup_count', torch.tensor(float(10.0 * random.random())))
-        self.warmup_end = warmup_batches
-
         for u,d in zip(encoder_unmasked_dims, encoder_dims):
             assert u <= d
@@ -140,70 +132,11 @@ class Zipformer(EncoderInterface):
             encoders.append(encoder)
         self.encoders = nn.ModuleList(encoders)

-        # initializes self.skip_layers and self.skip_modules
-        self._init_skip_modules()
         self.downsample_output = AttentionDownsample(encoder_dims[-1],
                                                      encoder_dims[-1],
                                                      downsample=output_downsampling_factor)
-
-    def get_warmup_count(self) -> float:
-        """
-        Returns a value that reflects how many times this function has been called in training mode.
-        """
-        ans = self.warmup_count.item()
-        if self.training:
-            if ans > 1000000.0:
-                # this ensures that as the number of batches gets large, the warmup count cycles rather
-                # than getting stuck at the smallest floating point value x such that x + 1 == x.
-                # this is necessary because get_layers_to_drop() relies on the warmup count changing
-                # on every batch.
-                next_count = 500000.0
-            else:
-                next_count = ans + 1.0
-            self.warmup_count.fill_(next_count)
-        return ans
-
-    def _get_layer_skip_dropout_prob(self):
-        if not self.training:
-            return 0.0
-        warmup_count = self.get_warmup_count()
-        min_dropout_prob = 0.025
-
-        if warmup_count > self.warmup_end:
-            return min_dropout_prob
-        else:
-            return 0.5 - (warmup_count / self.warmup_end) * (0.5 - min_dropout_prob)
-
-    def _init_skip_modules(self):
-        """
-        If self.zipformer_downampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer
-        indexed 4 (in zero indexing), with has subsapling_factor=4, we combine the output of
-        layers 2 and 3; and at the input of layer indexed 5, which which has subsampling_factor=2,
-        we combine the outputs of layers 1 and 5.
-        """
-        skip_layers = []
-        skip_modules = []
-        z = self.zipformer_downsampling_factors
-        for i in range(len(z)):
-            if i <= 1 or z[i-1] <= z[i]:
-                skip_layers.append(None)
-                skip_modules.append(nn.Identity())
-            else:
-                # TEMP
-                for j in range(i-2, -1, -1):
-                    if z[j] <= z[i] or j == 0:
-                        # TEMP logging statement.
-                        logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will "
-                                     f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}.")
-                        skip_layers.append(j)
-                        skip_modules.append(SimpleCombiner(self.encoder_dims[j],
-                                                           self.encoder_dims[i-1]))
-                        break
-        self.skip_layers = skip_layers
-        self.skip_modules = nn.ModuleList(skip_modules)
-
     def get_feature_masks(
             self,
             x: torch.Tensor) -> List[Union[float, Tensor]]:
@@ -288,25 +221,20 @@ class Zipformer(EncoderInterface):
         assert x.size(0) == lengths.max().item()
         mask = make_pad_mask(lengths)

-        outputs = []
         feature_masks = self.get_feature_masks(x)

         for i, module in enumerate(self.encoders):
             ds = self.zipformer_downsampling_factors[i]
-            if self.skip_layers[i] is not None:
-                layer_skip_dropout_prob = self._get_layer_skip_dropout_prob()
-                if (not self.training) or random.random() > layer_skip_dropout_prob:
-                    x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
             x = module(x,
                        feature_mask=feature_masks[i],
                        src_key_padding_mask=None if mask is None else mask[...,::ds])
-            outputs.append(x)

         x = self.downsample_output(x)
         # class Downsample has this rounding behavior..
         assert self.output_downsampling_factor == 2
         lengths = (lengths + 1) // 2

+        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+
         return x, lengths
@@ -436,13 +364,12 @@ class ZipformerEncoderLayer(nn.Module):
         delta = src - src_orig

         bypass_scale = self.bypass_scale
-        if self.training and random.random() > 0.1:
-            # with probability 0.9, in training mode clamp bypass_scale to [
-            # 0.3, 1.0 ]; this will encourage it to learn parameters within this
-            # range by making parameters that are outside that range range
-            # noisy.  For testing don't bother, as it will anyway end up
-            # learning values within this range or very close to it.
-            bypass_scale = bypass_scale.clamp(min=0.3, max=1.0)
+        if torch.jit.is_scripting() or (not self.training) or random.random() > 0.1:
+            # With probability 0.9 in training mode, and always in testing or
+            # scripting mode, clamp bypass_scale to [0.5, 1.0]; this encourages
+            # it to learn parameters within this range by making parameters
+            # outside that range noisy.
+            bypass_scale = bypass_scale.clamp(min=0.5, max=1.0)
         src = src_orig + delta * bypass_scale

         return self.whiten(src)
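
The clamp above is applied stochastically: roughly 10% of training batches use the unclamped bypass_scale, so values outside [0.5, 1.0] still influence the output (noisily) and receive gradient, while at test/scripting time the clamp is always applied. An illustrative sketch of the same trick on a standalone parameter (toy module, not recipe code):

```python
import random

import torch
import torch.nn as nn


class BypassScaleDemo(nn.Module):
    """Toy module: learns a per-channel bypass scale, clamped as in the patch."""

    def __init__(self, d_model: int = 4) -> None:
        super().__init__()
        self.bypass_scale = nn.Parameter(torch.full((d_model,), 0.9))

    def forward(self, src_orig: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
        delta = src - src_orig
        bypass_scale = self.bypass_scale
        # In training mode the clamp is skipped ~10% of the time, so parameters
        # outside [0.5, 1.0] still see a (noisy) gradient pulling them back;
        # at test time the clamp is always applied.
        if (not self.training) or random.random() > 0.1:
            bypass_scale = bypass_scale.clamp(min=0.5, max=1.0)
        return src_orig + delta * bypass_scale


demo = BypassScaleDemo()
x_in = torch.randn(10, 2, 4)   # (seq_len, batch, d_model)
x_out = torch.randn(10, 2, 4)
y = demo(x_in, x_out)
print(y.shape)  # torch.Size([10, 2, 4])
```
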
@@ -470,17 +397,16 @@ class ZipformerEncoder(nn.Module):
                  warmup_end: float
     ) -> None:
         super().__init__()
-
-        # keep track of how many times forward() has been called, for purposes
-        # of warmup.  do this with a floating-point count because integer counts
-        # can fail to survive model averaging.  initialize with a smallish
-        # random number so that different encoders use different random seeds in
-        # shared_rng get_layers_to_drop() while using the same random seeds
-        # across jobs.
-        self.register_buffer('warmup_count', torch.tensor(float(10.0 * random.random())))
-
+        # self.batch_count will be written to by the top-level training program.
+        # Note: at inference time this may be zero, but it should be treated as
+        # large; we can check self.training to tell the two situations apart.
+        self.batch_count = 0
         self.warmup_begin = warmup_begin
         self.warmup_end = warmup_end
+        # module_seed is for when we need a random number that is unique to this
+        # module but shared across jobs.  It is used to randomly select how many
+        # layers to drop, so that this stays consistent across worker tasks
+        # (for efficiency).
+        self.module_seed = torch.randint(0, 1000, ()).item()

         self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, dropout)
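
The module_seed comment is about keeping DDP workers balanced: every worker should drop the same number of layers in a given batch (so no worker does more work than the others), while which layers are dropped may differ per worker. A rough sketch of that shared/independent RNG split, with the count and selection logic simplified relative to the actual get_layers_to_drop() (the helper name is made up):

```python
import random
from typing import List, Set


def choose_layers_to_drop(
    layerdrop_probs: List[float],
    batch_count: int,
    module_seed: int,
    rnd_seed: int,
) -> Set[int]:
    """Illustrative only:
      * shared_rng depends on (batch_count, module_seed), so every DDP worker
        agrees on HOW MANY layers to drop for this module in this batch;
      * independent_rng is seeded per worker/batch, so WHICH layers are
        dropped can differ between workers.
    """
    shared_rng = random.Random(batch_count + module_seed)
    independent_rng = random.Random(rnd_seed)

    num_layers = len(layerdrop_probs)
    # How many to drop: decided with the shared RNG only.
    num_to_drop = sum(1 for p in layerdrop_probs if shared_rng.random() < p)

    ans: Set[int] = set()
    while len(ans) < num_to_drop:
        # Which to drop: worker-local choice, biased towards high-prob layers.
        layer = independent_rng.randrange(num_layers)
        if independent_rng.random() < layerdrop_probs[layer]:
            ans.add(layer)
    return ans


# Two "workers" share batch_count and module_seed but have different local seeds:
probs = [0.05, 0.5, 0.5, 0.05]
a = choose_layers_to_drop(probs, batch_count=1234, module_seed=7, rnd_seed=1)
b = choose_layers_to_drop(probs, batch_count=1234, module_seed=7, rnd_seed=2)
assert len(a) == len(b)   # same count on every worker
print(a, b)               # possibly different layer indices
```
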
@@ -501,27 +427,12 @@ class ZipformerEncoder(nn.Module):
             self.layers[i].warmup_end = cur_begin

+    def get_layers_to_drop(self, rnd_seed: int):
+        ans = set()
+        if not self.training:
+            return ans

-    def get_warmup_count(self) -> float:
-        """
-        Returns a value that reflects how many times this function has been called in training mode.
-        """
-        ans = self.warmup_count.item()
-        if self.training:
-            if ans > 1000000.0:
-                # this ensures that as the number of batches gets large, the warmup count cycles rather
-                # than getting stuck at the smallest floating point value x such that x + 1 == x.
-                # this is necessary because get_layers_to_drop() relies on the warmup count changing
-                # on every batch.
-                next_count = 500000.0
-            else:
-                next_count = ans + 1.0
-            self.warmup_count.fill_(next_count)
-        return ans
-
-
-    def get_layers_to_drop(self, rnd_seed: int, warmup_count: float):
-
+        batch_count = self.batch_count
         num_layers = len(self.layers)

         def get_layerdrop_prob(layer: int) -> float:
@@ -531,30 +442,26 @@ class ZipformerEncoder(nn.Module):
             initial_layerdrop_prob = 0.5
             final_layerdrop_prob = 0.05

-            if warmup_count < 20.0:
-                # As a special case, if warmup_count < 20.0 return 0 (drop no
+            if batch_count == 0:
+                # As a special case, if batch_count == 0, return 0 (drop no
                 # layers).  This is rather ugly, I'm afraid; it is intended to
                 # enable our scan_pessimistic_batches_for_oom() code to work correctly
                 # so if we are going to get OOM it will happen early.
-                # also search for 'warmup_count' with quotes in this file to see
-                # how we initialize the warmup count to a random number between
-                # 0 and 10.
+                # (batch_count stays at 0 until set_batch_count() is first
+                # called from train.py.)
                 return 0.0
-            elif warmup_count < layer_warmup_begin:
+            elif batch_count < layer_warmup_begin:
                 return initial_layerdrop_prob
-            elif warmup_count > layer_warmup_end:
+            elif batch_count > layer_warmup_end:
                 return final_layerdrop_prob
             else:
                 # linearly interpolate
-                t = (warmup_count - layer_warmup_begin) / layer_warmup_end
+                t = (batch_count - layer_warmup_begin) / layer_warmup_end
                 assert 0.0 <= t < 1.001
                 return initial_layerdrop_prob + t * (final_layerdrop_prob - initial_layerdrop_prob)

-        ans = set()
-        if not self.training:
-            return ans
-
-        shared_rng = random.Random(int(warmup_count * 1000))
+        shared_rng = random.Random(batch_count + self.module_seed)
         independent_rng = random.Random(rnd_seed)

         layerdrop_probs = [ get_layerdrop_prob(i) for i in range(num_layers) ]
@@ -577,7 +484,8 @@ class ZipformerEncoder(nn.Module):
                 if len(ans) == num_to_drop:
                     break
         if shared_rng.random() < 0.005 or __name__ == "__main__":
-            logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, warmup_count={warmup_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}")
+            logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, "
+                         f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}")
         return ans

@@ -611,7 +519,7 @@ class ZipformerEncoder(nn.Module):

         rnd_seed = src.numel() + random.randint(0, 1000)
-        layers_to_drop = self.get_layers_to_drop(rnd_seed, self.get_warmup_count())
+        layers_to_drop = self.get_layers_to_drop(rnd_seed)

         output = output * feature_mask
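
For reference, the layer-drop schedule implied by get_layerdrop_prob() above: probability 0 at batch_count == 0 (so the OOM scan sees the full model), 0.5 before a layer's warmup_begin, 0.05 after its warmup_end, and a linear ramp in between. A standalone sketch using the same expressions as the patch (the function name is just for illustration):

```python
def layerdrop_prob(
    batch_count: float,
    layer_warmup_begin: float,
    layer_warmup_end: float,
    initial_layerdrop_prob: float = 0.5,
    final_layerdrop_prob: float = 0.05,
) -> float:
    """Layer-drop probability as a function of the training batch count,
    mirroring get_layerdrop_prob() in the patch."""
    if batch_count == 0:
        # Special case: drop nothing, so scan_pessimistic_batches_for_oom()
        # exercises the full model and any OOM shows up early.
        return 0.0
    elif batch_count < layer_warmup_begin:
        return initial_layerdrop_prob
    elif batch_count > layer_warmup_end:
        return final_layerdrop_prob
    else:
        # Linear interpolation over the warmup window (same expression as the
        # patch, which divides by layer_warmup_end rather than the window size).
        t = (batch_count - layer_warmup_begin) / layer_warmup_end
        return initial_layerdrop_prob + t * (final_layerdrop_prob - initial_layerdrop_prob)


# E.g. for a layer warming up between batches 1000 and 3000:
for bc in (0, 500, 1000, 2000, 3000, 10000):
    print(bc, round(layerdrop_prob(bc, 1000.0, 3000.0), 3))
```
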