Remove layer dropout and model-level warmup

This commit is contained in:
Daniel Povey 2022-10-06 12:36:42 +08:00
parent 537c3537c0
commit 0685ac792d


@@ -203,7 +203,6 @@ class ConformerEncoderLayer(nn.Module):
        attn_scores_in: Optional[Tensor] = None,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        layerdrop_scale: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Pass the input through the encoder layer.
@@ -216,8 +215,6 @@ class ConformerEncoderLayer(nn.Module):
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            batch_split: if not None, this layer will only be applied to
            layerdrop_scale: an optional Tensor, broadcastable with `src`, that will be
               used as a scale on the change in the embeddings made by this layer.

        Shape:
            src: (S, N, E).
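The `layerdrop_scale` argument removed here scaled only the change a layer makes to its input, not the input itself. A minimal sketch of that pattern (apply_with_delta_scale and layer_fn are hypothetical names, not the actual ConformerEncoderLayer internals):

import torch
from torch import Tensor
from typing import Optional

def apply_with_delta_scale(layer_fn, src: Tensor,
                           layerdrop_scale: Optional[Tensor] = None) -> Tensor:
    # Run the layer, then scale only the delta it contributes: a scale of
    # 1.0 leaves the layer's effect unchanged, > 1.0 amplifies it.
    out = layer_fn(src)
    if layerdrop_scale is None:
        return out
    return src + layerdrop_scale * (out - src)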
@@ -286,76 +283,9 @@ class ConformerEncoder(nn.Module):
        )
        self.num_layers = num_layers
        self.to_layerdrop_scales = nn.Sequential(
            ScaledLinear(num_layers, 256, initial_scale=0.5),
            nn.ReLU(),
            ScaledLinear(256, num_layers, initial_scale=0.01))

        num_channels = encoder_layer.norm_final.num_channels
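ScaledLinear here is icefall's Linear variant; its initial_scale shrinks the initial parameter magnitudes. A rough stand-in using plain nn.Linear (assuming initial_scale acts, at initialization, like multiplying the weights), showing why 1.0 + mlp(mask) starts out near 1.0:

import torch
import torch.nn as nn

def make_scale_predictor(num_layers: int) -> nn.Sequential:
    first = nn.Linear(num_layers, 256)
    last = nn.Linear(256, num_layers)
    with torch.no_grad():
        first.weight.mul_(0.5)   # stands in for initial_scale=0.5
        last.weight.mul_(0.01)   # stands in for initial_scale=0.01
        last.bias.zero_()
    return nn.Sequential(first, nn.ReLU(), last)

scales = 1.0 + make_scale_predictor(12)(torch.ones(12))
print(scales)  # every entry close to 1.0 at initialization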
    def get_layerdrop_info(self) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Gets some random information that dictates the layer dropout
        configuration.

        Returns:
            (layerdrop_mask, layerdrop_scales)
          where:
            layerdrop_mask is a CPU tensor of shape (num_layers,), containing
            1.0 for layers to be evaluated and 0.0 for layers not to be
            evaluated.  It has constraints: at least one of the two halves of
            each layer must be evaluated, and successive layers of the same
            half must be evaluated.
            layerdrop_scales is a learned Tensor of shape (num_layers, batch_size, 1)
            of the form:
                1.0 + [learned matrix * (1.0 - layerdrop_scale)]
            where layerdrop_scale is 1.0 for layers that are computed, for this
            half, and 0.0 for layers that are not computed.  The intent is that
            layers neighboring layers that were not computed can learn a higher
            scale, to "make up" for the missing computation.  The reason for
            this specific functional form is that if everything is computed
            (layerdrop_scale is all 1.0), the scales are constrained to be
            exactly 1.0, avoiding redundant degrees of freedom.
        """
        num_layers = self.num_layers

        # This ensures that if we are using multiple worker processes, they all
        # use the same random numbers, so they will all take about the same
        # amount of time to process the batch.
        rng = random.Random(self.count)
        self.count += 1

        def get_random_mask():
            # 1.0 means don't drop the layer, 0.0 means drop the layer.
            mask = torch.ones(num_layers, device='cpu')
            if not self.training:
                return mask
            r = rng.random()
            if r < 0.1:
                # Drop zero layers, to make sure that sometimes we see the
                # complete network.
                return mask
            final_layers_dropped = 0
            if r < 0.1 + 0.25:
                # With prob 0.25: completely drop the last n layers; let n be a
                # multiple of 3 (this is what we used to do with aux_layers).
                final_layers_dropped = 3 * rng.randint(1, num_layers // 3)
                mask[-final_layers_dropped:] = 0.0
            layer_drop_prob = 0.075
            for i in range(num_layers - final_layers_dropped):
                mask[i] = (rng.random() > layer_drop_prob)
            if mask.sum() == 0.0:
                mask[0] = 1.0
            return mask
        mask = get_random_mask()

        device = self.to_layerdrop_scales[0].weight.device
        layerdrop_scales = 1.0 + self.to_layerdrop_scales(mask.to(device))
        if random.random() < 0.05:
            logging.info(f"mask={mask}, layerdrop_scales={layerdrop_scales.to('cpu')}")
        return mask, layerdrop_scales
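The sampling above mixes three regimes: with probability 0.1 the full network is kept, with probability 0.25 a trailing block of 3n layers is zeroed out, and every layer not already dropped is then dropped independently with probability 0.075. A standalone sketch (sample_layerdrop_mask is a hypothetical name) that reproduces just the sampling, e.g. to check the expected fraction of layers kept:

import random
import torch

def sample_layerdrop_mask(num_layers: int, seed: int) -> torch.Tensor:
    rng = random.Random(seed)
    mask = torch.ones(num_layers)
    r = rng.random()
    if r < 0.1:
        return mask  # keep the complete network ~10% of the time
    final_layers_dropped = 0
    if r < 0.1 + 0.25:
        # ~25% of the time: drop the last n layers, n a multiple of 3.
        final_layers_dropped = 3 * rng.randint(1, num_layers // 3)
        mask[-final_layers_dropped:] = 0.0
    for i in range(num_layers - final_layers_dropped):
        mask[i] = float(rng.random() > 0.075)
    if mask.sum() == 0.0:
        mask[0] = 1.0  # never drop every layer
    return mask

masks = torch.stack([sample_layerdrop_mask(12, s) for s in range(1000)])
print(masks.mean().item())  # roughly 0.8 of layers kept for num_layers=12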
    def forward(
@@ -402,24 +332,18 @@ class ConformerEncoder(nn.Module):
        frame_mask = (torch.rand_like(src[..., :1]) > feature_mask_dropout_prob).to(src.dtype)
        feature_mask[..., feature_unmasked_dim:] *= frame_mask

        # Deal with layer dropout.
        layerdrop_mask, layerdrop_scales = self.get_layerdrop_info()

        output = output * feature_mask

        for i, mod in enumerate(self.layers):
            if layerdrop_mask[i] != 0.0:
                output, attn_scores = mod(
                    output,
                    pos_emb,
                    attn_scores,
                    src_mask=mask,
                    src_key_padding_mask=src_key_padding_mask,
                    layerdrop_scale=layerdrop_scales[i],
                )
            output = output * feature_mask
            output, attn_scores = mod(
                output,
                pos_emb,
                attn_scores,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
            )
            output = output * feature_mask

        return output
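After this commit the loop body is unconditional: every layer always runs, with no per-layer skip or learned scale, and the feature mask is re-applied after each layer. A minimal skeleton of the remaining control flow, with nn.Linear placeholders standing in for the conformer layers:

import torch
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])  # placeholder layers
output = torch.randn(10, 2, 8)            # (S, N, E)
feature_mask = torch.ones_like(output)    # random frame/feature drops during training

output = output * feature_mask
for mod in layers:
    output = mod(output)                  # no layerdrop mask or scale anymore
    output = output * feature_mask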