Mirror of https://github.com/k2-fsa/icefall.git
Rework how warmup count is produced; should not affect results.
Commit 8b0722e626 (parent 6b6143f28c)
@@ -89,6 +89,17 @@ LRSchedulerType = Union[
]


def set_batch_count(
    model: Union[nn.Module, DDP], batch_count: float
) -> None:
    if isinstance(model, DDP):
        # get underlying nn.Module
        model = model.module
    for module in model.modules():
        if hasattr(module, 'batch_count'):
            module.batch_count = batch_count


def add_model_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--num-encoder-layers",
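A minimal sketch of how this attribute propagation behaves, using a toy module (the class and sizes below are illustrative, and the DDP unwrapping is omitted):

import torch.nn as nn

class ToyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.batch_count = 0.0   # picked up by set_batch_count() via hasattr()

model = nn.Sequential(ToyLayer(), ToyLayer())

# simplified copy of the function above, minus the DDP unwrap
def _set_batch_count(model: nn.Module, batch_count: float) -> None:
    for module in model.modules():
        if hasattr(module, 'batch_count'):
            module.batch_count = batch_count

_set_batch_count(model, 123.0)
print([m.batch_count for m in model.modules() if hasattr(m, 'batch_count')])  # [123.0, 123.0]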
@@ -809,6 +820,7 @@ def train_one_epoch(
        # NOTE: We use reduction==sum and loss is computed over utterances
        # in the batch and there is no normalization to it so far.
        scaler.scale(loss).backward()
        set_batch_count(model, params.batch_idx_train)
        scheduler.step_batch(params.batch_idx_train)

        scaler.step(optimizer)

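For orientation, a runnable toy version of the per-batch ordering this hunk establishes; the model, data and optimizer here are stand-ins, and the recipe's set_batch_count/scheduler.step_batch calls appear as comments since the toy model has no batch_count attribute or per-batch scheduler:

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=False)   # enabled=False so this runs on CPU

for batch_idx in range(3):
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    optimizer.zero_grad()
    loss = ((model(x) - y) ** 2).sum()               # reduction == sum, as in the NOTE above
    scaler.scale(loss).backward()
    # set_batch_count(model, batch_idx)              # the added call: right after backward()
    # scheduler.step_batch(batch_idx)                # per-batch LR update in the recipe
    scaler.step(optimizer)
    scaler.update()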
@@ -41,6 +41,7 @@ from scaling import (
from torch import Tensor, nn

from icefall.utils import make_pad_mask
from icefall.dist import get_rank


class Zipformer(EncoderInterface):
@@ -84,15 +85,6 @@ class Zipformer(EncoderInterface):
        self.zipformer_downsampling_factors = zipformer_downsampling_factors
        self.output_downsampling_factor = output_downsampling_factor

        # keep track of how many times forward() has been called, for purposes
        # of warmup.  do this with a floating-point count because integer counts
        # can fail to survive model averaging.  initialize with a smallish
        # random number so that different encoders use different random seeds in
        # the shared_rng of get_layers_to_drop(), while using the same random
        # seeds across jobs.
        self.register_buffer('warmup_count', torch.tensor(float(10.0 * random.random())))
        self.warmup_end = warmup_batches

        for u, d in zip(encoder_unmasked_dims, encoder_dims):
            assert u <= d

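A minimal sketch of the buffer-based counter pattern on its own (the class below is illustrative): registering the count with register_buffer puts it in the module's state_dict, so it is checkpointed and participates in model averaging, which is why it is kept as a float rather than an int.

import random
import torch
import torch.nn as nn

class CountedModule(nn.Module):
    def __init__(self):
        super().__init__()
        # smallish random init, as in the comment above
        self.register_buffer('warmup_count',
                             torch.tensor(float(10.0 * random.random())))

m = CountedModule()
print('warmup_count' in m.state_dict())               # True: saved with checkpoints
m.warmup_count.fill_(m.warmup_count.item() + 1.0)     # the in-place increment pattern
print(m.warmup_count.item())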
@@ -140,70 +132,11 @@ class Zipformer(EncoderInterface):
            encoders.append(encoder)
        self.encoders = nn.ModuleList(encoders)

        # initializes self.skip_layers and self.skip_modules
        self._init_skip_modules()

        self.downsample_output = AttentionDownsample(encoder_dims[-1],
                                                     encoder_dims[-1],
                                                     downsample=output_downsampling_factor)

    def get_warmup_count(self) -> float:
        """
        Returns a value that reflects how many times this function has been called in training mode.
        """
        ans = self.warmup_count.item()
        if self.training:
            if ans > 1000000.0:
                # this ensures that as the number of batches gets large, the warmup count cycles rather
                # than getting stuck at the smallest floating point value x such that x + 1 == x.
                # this is necessary because get_layers_to_drop() relies on the warmup count changing
                # on every batch.
                next_count = 500000.0
            else:
                next_count = ans + 1.0
            self.warmup_count.fill_(next_count)
        return ans

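The cycling above guards against float32 saturation: once a float32 counter reaches 2**24 (about 1.7e7), adding 1.0 no longer changes it, so a counter that only ever increments would stop moving and get_layers_to_drop() would stop seeing a fresh value each batch. A quick standalone check:

import torch

x = torch.tensor(float(2 ** 24))        # float32 by default
print((x + 1.0).item() == x.item())     # True: x + 1 == x at this magnitude
# the method above cycles the count back to 500000.0 once it exceeds 1000000.0,
# staying far below the point where increments would be lost.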
    def _get_layer_skip_dropout_prob(self):
        if not self.training:
            return 0.0
        warmup_count = self.get_warmup_count()
        min_dropout_prob = 0.025

        if warmup_count > self.warmup_end:
            return min_dropout_prob
        else:
            return 0.5 - (warmup_count / self.warmup_end) * (0.5 - min_dropout_prob)

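To make the schedule concrete, here is a standalone recomputation of the layer-skip dropout probability with an assumed warmup_end of 2000.0 (the value of warmup_batches is recipe-dependent; 0.5 and 0.025 are the constants from the method above):

def layer_skip_dropout_prob(warmup_count: float, warmup_end: float = 2000.0) -> float:
    min_dropout_prob = 0.025
    if warmup_count > warmup_end:
        return min_dropout_prob
    return 0.5 - (warmup_count / warmup_end) * (0.5 - min_dropout_prob)

print(layer_skip_dropout_prob(0.0))      # 0.5 at the very start of training
print(layer_skip_dropout_prob(1000.0))   # 0.2625 halfway through warmup
print(layer_skip_dropout_prob(5000.0))   # 0.025 once warmup is over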
    def _init_skip_modules(self):
        """
        If self.zipformer_downsampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer
        indexed 4 (in zero indexing), which has subsampling_factor=4, we combine the outputs of
        layers 2 and 3; and at the input of layer indexed 5, which has subsampling_factor=2,
        we combine the outputs of layers 1 and 4.
        """
        skip_layers = []
        skip_modules = []
        z = self.zipformer_downsampling_factors
        for i in range(len(z)):
            if i <= 1 or z[i-1] <= z[i]:
                skip_layers.append(None)
                skip_modules.append(nn.Identity())
            else:
                # TEMP
                for j in range(i-2, -1, -1):
                    if z[j] <= z[i] or j == 0:
                        # TEMP logging statement.
                        logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will "
                                     f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}.")
                        skip_layers.append(j)
                        skip_modules.append(SimpleCombiner(self.encoder_dims[j],
                                                           self.encoder_dims[i-1]))
                        break
        self.skip_layers = skip_layers
        self.skip_modules = nn.ModuleList(skip_modules)

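The index selection in the loop above can be checked in isolation. A plain-Python sketch that reproduces only the skip_layers computation (no combiner modules), applied to the factors from the docstring:

def compute_skip_layers(z):
    skip_layers = []
    for i in range(len(z)):
        if i <= 1 or z[i - 1] <= z[i]:
            skip_layers.append(None)
        else:
            for j in range(i - 2, -1, -1):
                if z[j] <= z[i] or j == 0:
                    skip_layers.append(j)
                    break
    return skip_layers

print(compute_skip_layers((1, 2, 4, 8, 4, 2)))
# [None, None, None, None, 2, 1]: stack 4 combines the outputs of stacks 2 and 3,
# and stack 5 combines the outputs of stacks 1 and 4.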
    def get_feature_masks(
            self,
            x: torch.Tensor) -> List[Union[float, Tensor]]:
@@ -288,25 +221,20 @@ class Zipformer(EncoderInterface):
        assert x.size(0) == lengths.max().item()
        mask = make_pad_mask(lengths)

        outputs = []
        feature_masks = self.get_feature_masks(x)

        for i, module in enumerate(self.encoders):
            ds = self.zipformer_downsampling_factors[i]
            if self.skip_layers[i] is not None:
                layer_skip_dropout_prob = self._get_layer_skip_dropout_prob()
                if (not self.training) or random.random() > layer_skip_dropout_prob:
                    x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
            x = module(x,
                       feature_mask=feature_masks[i],
                       src_key_padding_mask=None if mask is None else mask[..., ::ds])
            outputs.append(x)

        x = self.downsample_output(x)
        # class Downsample has this rounding behavior.
        assert self.output_downsampling_factor == 2
        lengths = (lengths + 1) // 2


        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)

        return x, lengths
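The mask[..., ::ds] indexing keeps every ds-th frame of the padding mask so it lines up with a stack running at downsampling factor ds. A small illustration with a hand-built mask following the make_pad_mask convention (True marks padded frames):

import torch

lengths = torch.tensor([6, 3])
T = int(lengths.max())
mask = torch.arange(T).unsqueeze(0) >= lengths.unsqueeze(1)   # (N, T), True = padding
print(mask[..., ::2].shape)   # torch.Size([2, 3]): aligned with a stack downsampled by 2
print(mask[..., ::2])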
@@ -436,13 +364,12 @@ class ZipformerEncoderLayer(nn.Module):

        delta = src - src_orig
        bypass_scale = self.bypass_scale
        if self.training and random.random() > 0.1:
            # with probability 0.9, in training mode clamp bypass_scale to
            # [0.3, 1.0]; this will encourage it to learn parameters within this
            # range by making parameters that are outside that range noisy.
            # For testing don't bother, as it will anyway end up learning
            # values within this range or very close to it.
            bypass_scale = bypass_scale.clamp(min=0.3, max=1.0)
        if torch.jit.is_scripting() or (not self.training) or random.random() > 0.1:
            # with probability 0.9 in training mode, and always in testing or
            # scripting mode, clamp bypass_scale to [0.5, 1.0]; this will encourage
            # it to learn parameters within this range by making parameters that
            # are outside that range noisy.
            bypass_scale = bypass_scale.clamp(min=0.5, max=1.0)
        src = src_orig + delta * bypass_scale

        return self.whiten(src)
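A minimal sketch of the stochastic-clamp idea on its own (a toy module, not the layer above): a learnable scalar is clamped only most of the time during training, so values outside the range still receive gradients but behave noisily, nudging the parameter back into range.

import random
import torch
import torch.nn as nn

class ScaledBypass(nn.Module):
    def __init__(self):
        super().__init__()
        self.bypass_scale = nn.Parameter(torch.tensor(0.8))

    def forward(self, src_orig: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
        scale = self.bypass_scale
        if not self.training or random.random() > 0.1:
            # clamp with probability 0.9 in training, always in eval
            scale = scale.clamp(min=0.5, max=1.0)
        return src_orig + delta * scale

m = ScaledBypass()
out = m(torch.randn(3), torch.randn(3))
out.sum().backward()                     # gradients still flow to bypass_scale
print(m.bypass_scale.grad is not None)   # True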
@@ -470,17 +397,16 @@ class ZipformerEncoder(nn.Module):
            warmup_end: float
    ) -> None:
        super().__init__()

        # keep track of how many times forward() has been called, for purposes
        # of warmup.  do this with a floating-point count because integer counts
        # can fail to survive model averaging.  initialize with a smallish
        # random number so that different encoders use different random seeds in
        # the shared_rng of get_layers_to_drop(), while using the same random
        # seeds across jobs.
        self.register_buffer('warmup_count', torch.tensor(float(10.0 * random.random())))

        # self.batch_count will be written to by the top-level training program.
        # Note: at inference time this may be zero but should be treated as large;
        # we can check whether self.training is true.
        self.batch_count = 0
        self.warmup_begin = warmup_begin
        self.warmup_end = warmup_end
        # module_seed is for when we need a random number that is unique to the module but
        # shared across jobs.  It's used to randomly select how many layers to drop,
        # so that we can keep this consistent across worker tasks (for efficiency).
        self.module_seed = torch.randint(0, 1000, ()).item()

        self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model,
                                                 dropout)
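The "shared across jobs" property of module_seed relies on every worker seeding the global torch RNG identically before the model is constructed. A quick standalone illustration, with an assumed seed of 42:

import torch

def make_module_seed():
    return torch.randint(0, 1000, ()).item()

torch.manual_seed(42)
a = make_module_seed()
torch.manual_seed(42)      # a second "worker" seeding identically
b = make_module_seed()
print(a == b)              # True: every job draws the same value for a given module,
                           # while successive modules within a job draw different ones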
@@ -501,27 +427,12 @@ class ZipformerEncoder(nn.Module):
            self.layers[i].warmup_end = cur_begin


    def get_layers_to_drop(self, rnd_seed: int):
        ans = set()
        if not self.training:
            return ans

    def get_warmup_count(self) -> float:
        """
        Returns a value that reflects how many times this function has been called in training mode.
        """
        ans = self.warmup_count.item()
        if self.training:
            if ans > 1000000.0:
                # this ensures that as the number of batches gets large, the warmup count cycles rather
                # than getting stuck at the smallest floating point value x such that x + 1 == x.
                # this is necessary because get_layers_to_drop() relies on the warmup count changing
                # on every batch.
                next_count = 500000.0
            else:
                next_count = ans + 1.0
            self.warmup_count.fill_(next_count)
        return ans


    def get_layers_to_drop(self, rnd_seed: int, warmup_count: float):

        batch_count = self.batch_count
        num_layers = len(self.layers)

        def get_layerdrop_prob(layer: int) -> float:
@@ -531,30 +442,26 @@ class ZipformerEncoder(nn.Module):
            initial_layerdrop_prob = 0.5
            final_layerdrop_prob = 0.05

            if warmup_count < 20.0:
                # As a special case, if warmup_count < 20.0 return 0 (drop no
            if batch_count == 0:
                # As a special case, if batch_count == 0, return 0 (drop no
                # layers).  This is rather ugly, I'm afraid; it is intended to
                # enable our scan_pessimistic_batches_for_oom() code to work correctly,
                # so that if we are going to get OOM it will happen early.
                # also search for 'warmup_count' with quotes in this file to see
                # also search for 'batch_count' with quotes in this file to see
                # how we initialize the warmup count to a random number between
                # 0 and 10.
                return 0.0
            elif warmup_count < layer_warmup_begin:
            elif batch_count < layer_warmup_begin:
                return initial_layerdrop_prob
            elif warmup_count > layer_warmup_end:
            elif batch_count > layer_warmup_end:
                return final_layerdrop_prob
            else:
                # linearly interpolate
                t = (warmup_count - layer_warmup_begin) / layer_warmup_end
                t = (batch_count - layer_warmup_begin) / layer_warmup_end
                assert 0.0 <= t < 1.001
                return initial_layerdrop_prob + t * (final_layerdrop_prob - initial_layerdrop_prob)

        ans = set()
        if not self.training:
            return ans

        shared_rng = random.Random(int(warmup_count * 1000))
        shared_rng = random.Random(batch_count + self.module_seed)
        independent_rng = random.Random(rnd_seed)

        layerdrop_probs = [ get_layerdrop_prob(i) for i in range(num_layers) ]
@@ -577,7 +484,8 @@ class ZipformerEncoder(nn.Module):
            if len(ans) == num_to_drop:
                break
        if shared_rng.random() < 0.005 or __name__ == "__main__":
            logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, warmup_count={warmup_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}")
            logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, "
                         f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}")
        return ans


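A small standalone sketch of why the shared RNG is now seeded from batch_count + module_seed rather than the warmup count: every DDP worker sees the same batch_count (written by set_batch_count from the training loop) and the same module_seed (drawn from the globally seeded torch RNG), so they all agree on how many layers to drop, while the per-call independent_rng still varies. The names, sizes, and the drop-count choice below are illustrative stand-ins, not the recipe's exact logic.

import random

module_seed = 374        # would come from torch.randint(0, 1000, ()).item()
batch_count = 1200       # same value on every worker for a given training step
num_layers = 6

def layers_to_drop_on_worker(rnd_seed: int):
    shared_rng = random.Random(batch_count + module_seed)
    independent_rng = random.Random(rnd_seed)   # differs per worker/call
    num_to_drop = shared_rng.randint(0, 2)      # stand-in for the probability-based choice
    ans = set()
    while len(ans) < num_to_drop:
        ans.add(independent_rng.randint(0, num_layers - 1))
    return num_to_drop, ans

print(layers_to_drop_on_worker(rnd_seed=11))
print(layers_to_drop_on_worker(rnd_seed=97))
# num_to_drop is identical across workers (shared_rng), the specific layers differ
# (independent_rng), matching the comment about keeping the drop count consistent
# "across worker tasks (for efficiency)".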
@@ -611,7 +519,7 @@ class ZipformerEncoder(nn.Module):


        rnd_seed = src.numel() + random.randint(0, 1000)
        layers_to_drop = self.get_layers_to_drop(rnd_seed, self.get_warmup_count())
        layers_to_drop = self.get_layers_to_drop(rnd_seed)

        output = output * feature_mask

