Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Do warmup by dropping out whole layers.
This commit is contained in: parent 5255969544, commit e6540865f3
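The technique in one self-contained sketch (the wrapper class name and call signature below are illustrative, not taken from this commit): during warmup, an encoder layer's entire contribution is sometimes dropped, so early batches effectively train a shallower network, and the drop probability is annealed away as training proceeds.

    import torch
    import torch.nn as nn

    class DropWholeLayer(nn.Module):
        # Illustrative wrapper: with probability p, bypass the wrapped layer
        # entirely, so the stack behaves as if the layer did not exist for this batch.
        def __init__(self, layer: nn.Module):
            super().__init__()
            self.layer = layer

        def forward(self, x: torch.Tensor, p: float) -> torch.Tensor:
            if self.training and torch.rand(1).item() < p:
                return x          # whole layer dropped: identity bypass
            return self.layer(x)  # normal path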
@@ -260,11 +260,6 @@ class ConformerEncoderLayer(nn.Module):
        self.d_model = d_model

        # we'll overwrite these warmup_begin and warmup_end values from init of
        # class ConformerEncoder.
        self.warmup_begin = 0.0
        self.warmup_end = 1000.0

        self.self_attn = RelPositionMultiheadAttention(
            d_model, nhead, dropout=dropout,
        )
@@ -304,18 +299,6 @@ class ConformerEncoderLayer(nn.Module):
            max_var_per_eig=0.2,
        )

    def get_warmup_value(self, warmup_count: float) -> float:
        """
        Returns a value that is 0 at the start of training and increases to 1.0 during
        a warmup period specified during model initialization.
        """
        if warmup_count < self.warmup_begin:
            return 0.0
        elif warmup_count > self.warmup_end:
            return 1.0
        else:
            return (warmup_count - self.warmup_begin) / (self.warmup_end - self.warmup_begin)

    def forward(
        self,
        src: Tensor,
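For reference, the ramp that get_warmup_value computes, pulled out as a standalone function so the numbers are easy to check (begin=0.0 and end=1000.0 match the defaults set earlier in this diff):

    def warmup_value(count: float, begin: float = 0.0, end: float = 1000.0) -> float:
        # 0.0 before `begin`, 1.0 after `end`, linear in between.
        if count < begin:
            return 0.0
        elif count > end:
            return 1.0
        return (count - begin) / (end - begin)

    # e.g. warmup_value(0.0) == 0.0, warmup_value(500.0) == 0.5, warmup_value(2000.0) == 1.0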
@@ -323,7 +306,6 @@ class ConformerEncoderLayer(nn.Module):
        attn_scores_in: Optional[Tensor] = None,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        warmup_count: float = 1.0e+10,
    ) -> Tuple[Tensor, Tensor]:
        """
        Pass the input through the encoder layer.
@@ -368,16 +350,10 @@ class ConformerEncoderLayer(nn.Module):

        src = self.norm_final(self.balancer(src))

        warmup_value = self.get_warmup_value(warmup_count)

        delta = src - src_orig
        if warmup_value < 1.0 and self.training:
            keep_prob = 0.5 + 0.5 * warmup_value
            # the :1 means the mask is chosen per frame.
            delta = delta * (torch.rand_like(delta[..., :1]) < keep_prob)
        bypass_scale = self.bypass_scale
        if random.random() > 0.1:
            bypass_scale = bypass_scale.clamp(min=0.1)
        bypass_scale = bypass_scale.clamp(min=0.1, max=1.0)
        src = src_orig + delta * self.bypass_scale

        return src, attn_scores_out
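A toy illustration of the per-frame masking above: the random mask is built from delta[..., :1], so it has a single channel and one keep/drop decision is broadcast across all channels of a frame; the tensor sizes and the bypass_scale value here are made up for the example.

    import torch

    T, N, C = 4, 2, 8                    # (frames, batch, channels); illustrative sizes
    src_orig = torch.randn(T, N, C)      # layer input
    src = torch.randn(T, N, C)           # stands in for the layer's output
    delta = src - src_orig

    keep_prob = 0.75                     # 0.5 + 0.5 * warmup_value, at warmup_value == 0.5
    mask = torch.rand_like(delta[..., :1]) < keep_prob   # shape (T, N, 1): one decision per frame
    delta = delta * mask                 # dropped frames contribute nothing

    bypass_scale = torch.tensor(0.5)     # made-up value for the learned scale
    out = src_orig + delta * bypass_scale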
@@ -411,8 +387,8 @@ class ConformerEncoder(nn.Module):
        # fail to survive model averaging.
        self.register_buffer('warmup_count', torch.tensor(0.0))

        # if this assert fails, increase the numbers in get_warmup_count().
        assert warmup_end <= 1000000.0
        self.warmup_begin = warmup_begin
        self.warmup_end = warmup_end

        self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model,
                                                 dropout)
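Why warmup_count is a registered buffer rather than a plain Python attribute: buffers are part of the module's state_dict, so the counter is saved and restored with checkpoints. A rough sketch of keeping and advancing such a counter follows; the body of get_warmup_count is not shown in this diff, so the increment-once-per-call behaviour below is an assumption.

    import torch
    import torch.nn as nn

    class WarmupCounter(nn.Module):
        def __init__(self):
            super().__init__()
            # Stored as a buffer so it is written to and read from checkpoints.
            self.register_buffer('warmup_count', torch.tensor(0.0))

        def get_warmup_count(self) -> float:
            # Hypothetical: advance once per call (e.g. once per training batch)
            # and return the previous value; the real logic may differ.
            ans = float(self.warmup_count)
            self.warmup_count += 1.0
            return ans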
@@ -453,25 +429,40 @@ class ConformerEncoder(nn.Module):


    def get_layers_to_drop(self, warmup_count: float):

        num_layers = len(self.layers)
        warmup_begin = self.warmup_begin
        warmup_end = self.warmup_end

        def get_layerdrop_prob(layer: int) -> float:
            layer_warmup_delta = (warmup_end - warmup_begin) / num_layers
            layer_warmup_begin = warmup_begin + layer * layer_warmup_delta
            initial_layerdrop_prob = 0.75
            final_layerdrop_prob = 0.05

            layer_warmup_end = layer_warmup_begin + layer_warmup_delta
            if warmup_count < layer_warmup_begin:
                return initial_layerdrop_prob
            elif warmup_count > layer_warmup_end:
                return final_layerdrop_prob
            else:
                # linearly interpolate
                t = (warmup_count - layer_warmup_begin) / layer_warmup_delta
                assert 0.0 <= t < 1.001
                return initial_layerdrop_prob + t * (final_layerdrop_prob - initial_layerdrop_prob)

        ans = set()
        if not self.training:
            return ans
        # We use a random number generator seeded from warmup_count because
        # if there are multiple training processes we want them to all drop the
        # same number of layers (not necessarily the same layers though). This
        # will tend to minimize training time.
        rng = random.Random(int(warmup_count))
        num_layers = len(self.layers)

        # x is the expected number of layers to drop
        x = 0.075 * num_layers
        # integerize x in a way that preserves expectations.
        num_layers_to_drop = int(x) + int(rng.random() < (x - int(x)))
        while len(ans) < num_layers_to_drop:
            # use random, not rng here, because we don't want the same specific layers to be dropped.
            ans.add(random.randrange(0, num_layers))
        for layer in range(num_layers):
            if random.random() < get_layerdrop_prob(layer):
                ans.add(layer)
        if random.random() < 0.005 or __name__ == "__main__":
            logging.info(f"warmup_begin={warmup_begin}, warmup_end={warmup_end}, warmup_count={warmup_count}, layers_to_drop={ans}")
        return ans


    def forward(
        self,
        src: Tensor,
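Two pieces of get_layers_to_drop, isolated so the schedule is easy to sanity-check: each layer's drop probability decays from 0.75 to 0.05 over its own slice of the warmup period (later layers anneal later), and the expected number of layers to drop, 0.075 * num_layers, is turned into an integer by stochastic rounding so the expectation is preserved. This restates logic already in the hunk above; only the standalone function names are new.

    import random

    def layerdrop_prob(layer: int, warmup_count: float, warmup_begin: float,
                       warmup_end: float, num_layers: int) -> float:
        # Layer `layer` ramps 0.75 -> 0.05 over its own warmup slice.
        delta = (warmup_end - warmup_begin) / num_layers
        begin = warmup_begin + layer * delta
        end = begin + delta
        if warmup_count < begin:
            return 0.75
        elif warmup_count > end:
            return 0.05
        t = (warmup_count - begin) / delta
        return 0.75 + t * (0.05 - 0.75)

    def stochastic_round(x: float, rng: random.Random) -> int:
        # E[stochastic_round(x)] == x; e.g. 0.9 -> 1 with probability 0.9, else 0.
        return int(x) + int(rng.random() < (x - int(x)))

    # With a 12-layer encoder, 0.075 * 12 == 0.9 layers are dropped per batch on average.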
@@ -497,14 +488,13 @@ class ConformerEncoder(nn.Module):

        Returns: (x, x_no_combine), both of shape (S, N, E)
        """
        warmup_count = self.get_warmup_count()  # reflects number of training batches.
        pos_emb = self.encoder_pos(src)
        output = src

        outputs = []
        attn_scores = None

        layers_to_drop = self.get_layers_to_drop(warmup_count)
        layers_to_drop = self.get_layers_to_drop(self.get_warmup_count())

        output = output * feature_mask
@@ -517,7 +507,6 @@ class ConformerEncoder(nn.Module):
                attn_scores,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                warmup_count=warmup_count,
            )

            output = output * feature_mask
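The diff does not show how layers_to_drop is consumed inside the layer loop; one plausible reading, given purely as an assumption, is that a dropped layer is skipped outright so its input passes through unchanged for that batch:

    import torch
    import torch.nn as nn
    from typing import Set

    def run_stack(layers: nn.ModuleList, x: torch.Tensor, layers_to_drop: Set[int]) -> torch.Tensor:
        # Hypothetical loop body; the real code may instead pass warmup
        # information down into each layer, as the hunk above suggests.
        for i, mod in enumerate(layers):
            if i in layers_to_drop:
                continue  # whole layer dropped: output stays equal to its input
            x = mod(x)
        return x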