Rework how warmup count is produced; should not affect results.

Daniel Povey 2022-10-30 14:17:41 +08:00
parent 6b6143f28c
commit 8b0722e626
2 changed files with 43 additions and 123 deletions
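The mechanism itself is small enough to sketch end to end: the training loop stamps the current batch index onto every submodule that exposes a `batch_count` attribute (unwrapping DDP first), and warmup-dependent code inside the model reads that attribute instead of maintaining its own counter. A minimal toy version of the pattern, assuming a made-up `ToyEncoder` rather than the actual icefall classes, with `set_batch_count` copied in simplified form from the hunk below:

```python
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from typing import Union

def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
    if isinstance(model, DDP):
        model = model.module  # get underlying nn.Module
    for module in model.modules():
        if hasattr(module, 'batch_count'):
            module.batch_count = batch_count

class ToyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        # Written to by the training loop; treated as "large" at inference time.
        self.batch_count = 0
        self.proj = nn.Linear(4, 4)

    def forward(self, x):
        # Warmup-dependent behaviour keys off self.batch_count while training.
        print(f"training={self.training}, batch_count={self.batch_count}")
        return self.proj(x)

model = ToyEncoder()
for batch_idx in range(1, 4):
    set_batch_count(model, batch_idx)   # as train.py now does on every batch
    model(torch.randn(2, 4))
```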

View File

@@ -89,6 +89,17 @@ LRSchedulerType = Union[
]
def set_batch_count(
model: Union[nn.Module, DDP], batch_count: float
) -> None:
if isinstance(model, DDP):
# get underlying nn.Module
model = model.module
for module in model.modules():
if hasattr(module, 'batch_count'):
module.batch_count = batch_count
def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--num-encoder-layers",
@@ -809,6 +820,7 @@ def train_one_epoch(
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
scaler.scale(loss).backward()
set_batch_count(model, params.batch_idx_train)
scheduler.step_batch(params.batch_idx_train)
scaler.step(optimizer)
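To see where the new call sits, here is a stripped-down version of the training step above. The toy model, optimizer and scaler are placeholders, the icefall `scheduler.step_batch` call is left as a comment, and `set_batch_count` is repeated inline (without DDP unwrapping) so the sketch is self-contained:

```python
import torch
import torch.nn as nn

def set_batch_count(model: nn.Module, batch_count: float) -> None:
    # same idea as the function added above; DDP unwrapping omitted here
    for module in model.modules():
        if hasattr(module, "batch_count"):
            module.batch_count = batch_count

model = nn.Linear(8, 8)
model.batch_count = 0                                # pretend one submodule tracks warmup
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=False)    # enabled=False so this runs on CPU too

for batch_idx_train in range(1, 4):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 8)).sum()            # toy loss; reduction == sum, no normalization
    scaler.scale(loss).backward()
    set_batch_count(model, batch_idx_train)          # stamp the count before stepping
    # scheduler.step_batch(batch_idx_train)          # icefall's per-batch LR update goes here
    scaler.step(optimizer)
    scaler.update()
```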

View File

@@ -41,6 +41,7 @@ from scaling import (
from torch import Tensor, nn
from icefall.utils import make_pad_mask
from icefall.dist import get_rank
class Zipformer(EncoderInterface):
@@ -84,15 +85,6 @@ class Zipformer(EncoderInterface):
self.zipformer_downsampling_factors = zipformer_downsampling_factors
self.output_downsampling_factor = output_downsampling_factor
# keep track of how many times forward() has been called, for purposes
# of warmup. do this with a floating-point count because integer counts
# can fail to survive model averaging. initialize with a smallish
# random number so that different encoders use different random seeds in
# the shared_rng used by get_layers_to_drop(), while using the same random seeds
# across jobs.
self.register_buffer('warmup_count', torch.tensor(float(10.0 * random.random())))
self.warmup_end = warmup_batches
for u,d in zip(encoder_unmasked_dims, encoder_dims):
assert u <= d
@@ -140,70 +132,11 @@ class Zipformer(EncoderInterface):
encoders.append(encoder)
self.encoders = nn.ModuleList(encoders)
# initializes self.skip_layers and self.skip_modules
self._init_skip_modules()
self.downsample_output = AttentionDownsample(encoder_dims[-1],
encoder_dims[-1],
downsample=output_downsampling_factor)
def get_warmup_count(self) -> float:
"""
Returns a value that reflects how many times this function has been called in training mode.
"""
ans = self.warmup_count.item()
if self.training:
if ans > 1000000.0:
# this ensures that as the number of batches gets large, the warmup count cycles rather
# than getting stuck at the smallest floating point value x such that x + 1 == x.
# this is necessary because get_layers_to_drop() relies on the warmup count changing
# on every batch.
next_count = 500000.0
else:
next_count = ans + 1.0
self.warmup_count.fill_(next_count)
return ans
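The cycling above guards against float32 running out of integer precision. A quick illustration of the saturation point the comment refers to (the 1000000.0 threshold in the code is simply a conservative margin below it):

```python
import torch

x = torch.tensor(16777216.0)   # 2**24, where float32 stops resolving + 1.0
print(x + 1.0 == x)            # tensor(True): adding 1.0 no longer changes x
# Hence the wrap-around above: once the count passes 1000000.0 it is reset to
# 500000.0, so it keeps changing on every batch, well away from saturation.
```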
def _get_layer_skip_dropout_prob(self):
if not self.training:
return 0.0
warmup_count = self.get_warmup_count()
min_dropout_prob = 0.025
if warmup_count > self.warmup_end:
return min_dropout_prob
else:
return 0.5 - (warmup_count / self.warmup_end) * (0.5 - min_dropout_prob)
def _init_skip_modules(self):
"""
If self.zipformer_downsampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer
indexed 4 (in zero indexing), which has downsampling_factor=4, we combine the outputs of
layers 2 and 3; and at the input of layer indexed 5, which has downsampling_factor=2,
we combine the outputs of layers 1 and 4.
"""
skip_layers = []
skip_modules = []
z = self.zipformer_downsampling_factors
for i in range(len(z)):
if i <= 1 or z[i-1] <= z[i]:
skip_layers.append(None)
skip_modules.append(nn.Identity())
else:
# TEMP
for j in range(i-2, -1, -1):
if z[j] <= z[i] or j == 0:
# TEMP logging statement.
logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will "
f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}.")
skip_layers.append(j)
skip_modules.append(SimpleCombiner(self.encoder_dims[j],
self.encoder_dims[i-1]))
break
self.skip_layers = skip_layers
self.skip_modules = nn.ModuleList(skip_modules)
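As a sanity check on the docstring above, the skip-layer indices can be reproduced with the same loop in plain Python (a standalone copy of the index computation, not the module code):

```python
z = (1, 2, 4, 8, 4, 2)   # zipformer_downsampling_factors from the docstring
skip_layers = []
for i in range(len(z)):
    if i <= 1 or z[i - 1] <= z[i]:
        skip_layers.append(None)       # no skip connection into this stack
    else:
        for j in range(i - 2, -1, -1):
            if z[j] <= z[i] or j == 0:
                skip_layers.append(j)  # combine outputs of layers j and i-1
                break
print(skip_layers)   # [None, None, None, None, 2, 1]:
                     # stack 4 combines layers 2 and 3, stack 5 combines layers 1 and 4
```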
def get_feature_masks(
self,
x: torch.Tensor) -> List[Union[float, Tensor]]:
@@ -288,25 +221,20 @@ class Zipformer(EncoderInterface):
assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)
outputs = []
feature_masks = self.get_feature_masks(x)
for i, module in enumerate(self.encoders):
ds = self.zipformer_downsampling_factors[i]
if self.skip_layers[i] is not None:
layer_skip_dropout_prob = self._get_layer_skip_dropout_prob()
if (not self.training) or random.random() > layer_skip_dropout_prob:
x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
x = module(x,
feature_mask=feature_masks[i],
src_key_padding_mask=None if mask is None else mask[...,::ds])
outputs.append(x)
x = self.downsample_output(x)
# class Downsample has this rounding behavior..
assert self.output_downsampling_factor == 2
lengths = (lengths + 1) // 2
x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
return x, lengths
@@ -436,13 +364,12 @@ class ZipformerEncoderLayer(nn.Module):
delta = src - src_orig
bypass_scale = self.bypass_scale
if self.training and random.random() > 0.1:
# with probability 0.9, in training mode clamp bypass_scale to [
# 0.3, 1.0 ]; this will encourage it to learn parameters within this
# range by making parameters that are outside that range
# noisy. For testing don't bother, as it will anyway end up
# learning values within this range or very close to it.
bypass_scale = bypass_scale.clamp(min=0.3, max=1.0)
if torch.jit.is_scripting() or (not self.training) or random.random() > 0.1:
# with probability 0.9, in training mode, or always, in testing
# mode, clamp bypass_scale to [ 0.5, 1.0 ]; this will encourage it
# to learn parameters within this range by making parameters that
# are outside that range noisy.
bypass_scale = bypass_scale.clamp(min=0.5, max=1.0)
src = src_orig + delta * bypass_scale
return self.whiten(src)
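The comment above describes a probabilistic clamp on a learnable scale. A toy illustration of the same idea, using a standalone hypothetical parameter rather than the actual ZipformerEncoderLayer:

```python
import random
import torch
import torch.nn as nn

bypass_scale = nn.Parameter(torch.tensor(0.2))   # a value below the allowed range [0.5, 1.0]
training = True

for step in range(10):
    scale = bypass_scale
    # clamp with probability 0.9 in training mode, always otherwise
    if (not training) or random.random() > 0.1:
        scale = scale.clamp(min=0.5, max=1.0)
    # most steps see the clamped value 0.5; the occasional unclamped step makes an
    # out-of-range setting noisy, which (per the comment above) encourages the
    # parameter to stay inside the range
    print(step, float(scale))
```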
@@ -470,17 +397,16 @@ class ZipformerEncoder(nn.Module):
warmup_end: float
) -> None:
super().__init__()
# keep track of how many times forward() has been called, for purposes
# of warmup. do this with a floating-point count because integer counts
# can fail to survive model averaging. initialize with a smallish
# random number so that different encoders use different random seeds in
# the shared_rng used by get_layers_to_drop(), while using the same random seeds
# across jobs.
self.register_buffer('warmup_count', torch.tensor(float(10.0 * random.random())))
# self.batch_count will be written to by the top-level training program.
# Note: at inference time this may be zero but should be treated as large;
# we can check self.training to tell the two cases apart.
self.batch_count = 0
self.warmup_begin = warmup_begin
self.warmup_end = warmup_end
# module_seed is for when we need a random number that is unique to the module but
# shared across jobs. It's used to randomly select how many layers to drop,
# so that we can keep this consistent across worker tasks (for efficiency).
self.module_seed = torch.randint(0, 1000, ()).item()
self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model,
dropout)
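The module_seed comment relies on random.Random being deterministic for a given seed, so every worker that seeds with the same batch_count + module_seed draws identical values. A small illustration with made-up numbers:

```python
import random

module_seed = 123      # per the comment, the same value in every job
batch_count = 4000

# Two "workers" build shared RNGs from the same seed...
shared_a = random.Random(batch_count + module_seed)
shared_b = random.Random(batch_count + module_seed)
print(shared_a.random() == shared_b.random())   # True: identical draws on every worker

# ...while their independently seeded RNGs differ.
indep_a = random.Random(1)
indep_b = random.Random(2)
print(indep_a.random() == indep_b.random())     # False: different seeds, different draws
```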
@@ -501,27 +427,12 @@ class ZipformerEncoder(nn.Module):
self.layers[i].warmup_end = cur_begin
def get_layers_to_drop(self, rnd_seed: int):
ans = set()
if not self.training:
return ans
def get_warmup_count(self) -> float:
"""
Returns a value that reflects how many times this function has been called in training mode.
"""
ans = self.warmup_count.item()
if self.training:
if ans > 1000000.0:
# this ensures that as the number of batches gets large, the warmup count cycles rather
# than getting stuck at the smallest floating point value x such that x + 1 == x.
# this is necessary because get_layers_to_drop() relies on the warmup count changing
# on every batch.
next_count = 500000.0
else:
next_count = ans + 1.0
self.warmup_count.fill_(next_count)
return ans
def get_layers_to_drop(self, rnd_seed: int, warmup_count: float):
batch_count = self.batch_count
num_layers = len(self.layers)
def get_layerdrop_prob(layer: int) -> float:
@@ -531,30 +442,26 @@ class ZipformerEncoder(nn.Module):
initial_layerdrop_prob = 0.5
final_layerdrop_prob = 0.05
if warmup_count < 20.0:
# As a special case, if warmup_count < 20.0 return 0 (drop no
if batch_count == 0:
# As a special case, if batch_count == 0, return 0 (drop no
# layers). This is rather ugly, I'm afraid; it is intended to
# enable our scan_pessimistic_batches_for_oom() code to work correctly
# so if we are going to get OOM it will happen early.
# also search for 'warmup_count' with quotes in this file to see
# also search for 'batch_count' with quotes in this file to see
# how we initialize the warmup count to a random number between
# 0 and 10.
return 0.0
elif warmup_count < layer_warmup_begin:
elif batch_count < layer_warmup_begin:
return initial_layerdrop_prob
elif warmup_count > layer_warmup_end:
elif batch_count > layer_warmup_end:
return final_layerdrop_prob
else:
# linearly interpolate
t = (warmup_count - layer_warmup_begin) / layer_warmup_end
t = (batch_count - layer_warmup_begin) / layer_warmup_end
assert 0.0 <= t < 1.001
return initial_layerdrop_prob + t * (final_layerdrop_prob - initial_layerdrop_prob)
ans = set()
if not self.training:
return ans
shared_rng = random.Random(int(warmup_count * 1000))
shared_rng = random.Random(batch_count + self.module_seed)
independent_rng = random.Random(rnd_seed)
layerdrop_probs = [ get_layerdrop_prob(i) for i in range(num_layers) ]
@@ -577,7 +484,8 @@ class ZipformerEncoder(nn.Module):
if len(ans) == num_to_drop:
break
if shared_rng.random() < 0.005 or __name__ == "__main__":
logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, warmup_count={warmup_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}")
logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, "
f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}")
return ans
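To make the get_layerdrop_prob schedule concrete, here is the same piecewise-linear ramp evaluated at a few batch counts; the warmup boundaries below are made-up numbers, not the recipe's defaults:

```python
def layerdrop_prob(batch_count, warmup_begin, warmup_end,
                   initial=0.5, final=0.05):
    if batch_count == 0:
        return 0.0                      # special case for the OOM-scan batches
    elif batch_count < warmup_begin:
        return initial
    elif batch_count > warmup_end:
        return final
    else:
        # linear interpolation, dividing by warmup_end as the original code does
        t = (batch_count - warmup_begin) / warmup_end
        return initial + t * (final - initial)

for bc in [0, 500, 2000, 6000]:
    print(bc, layerdrop_prob(bc, warmup_begin=1000.0, warmup_end=4000.0))
# 0 and 500 give 0.0 and 0.5; 2000 falls on the decreasing ramp; 6000 gives 0.05
```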
@@ -611,7 +519,7 @@ class ZipformerEncoder(nn.Module):
rnd_seed = src.numel() + random.randint(0, 1000)
layers_to_drop = self.get_layers_to_drop(rnd_seed, self.get_warmup_count())
layers_to_drop = self.get_layers_to_drop(rnd_seed)
output = output * feature_mask