Implement efficient layer dropout

Daniel Povey 2022-10-03 17:19:16 +08:00
parent 93dff29243
commit b3af9f67ae


@@ -60,7 +60,7 @@ class Conformer(EncoderInterface):
dim_feedforward: int = 2048,
num_encoder_layers: int = 12,
dropout: float = 0.1,
layer_dropout: float = 0.075,
layer_dropout: float = 0.333,
cnn_module_kernel: int = 31,
aux_layer_period: int = 3,
) -> None:
@@ -85,13 +85,13 @@ class Conformer(EncoderInterface):
nhead,
dim_feedforward,
dropout,
layer_dropout,
cnn_module_kernel,
)
self.encoder = ConformerEncoder(
encoder_layer,
num_encoder_layers,
aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
layer_dropout=layer_dropout,
)
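As a quick check of the aux_layers argument constructed above: with the defaults in this file (num_encoder_layers=12, aux_layer_period=3) it selects every third layer, excluding the final one.

    num_encoder_layers = 12
    aux_layer_period = 3
    aux_layers = list(range(0, num_encoder_layers - 1, aux_layer_period))
    print(aux_layers)  # [0, 3, 6, 9]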
@@ -160,13 +160,10 @@ class ConformerEncoderLayer(nn.Module):
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
layer_dropout: float = 0.075,
cnn_module_kernel: int = 31,
) -> None:
super(ConformerEncoderLayer, self).__init__()
self.layer_dropout = layer_dropout
self.d_model = d_model
self.self_attn = RelPositionMultiheadAttention(
@@ -215,53 +212,80 @@ class ConformerEncoderLayer(nn.Module):
def forward(
self,
src: Tensor,
feature_mask: Union[Tensor, float],
pos_emb: Tensor,
attn_scores_in: Optional[Tensor] = None,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0,
) -> Tensor:
batch_split: Optional[bool] = None,
layerdrop_indicator: float = 1.0,
) -> Tuple[Tensor, Tensor]:
"""
Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
pos_emb: Positional embedding tensor (required).
attn_scores_in: something with the dimensions of the attention weights (bsz * num_heads, len, len) that is
attn_scores_in: something with the dimensions of the attention weights (bsz, len, len, num_heads) that is
passed from layer to layer.
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
feature_mask: a mask of shape (S, N, E) that randomly zeroes out
some of the features on each frame.
warmup: controls selective bypass of layers; if < 1.0, we will
bypass layers more frequently.
batch_split: if not None, this layer will only be applied to
part of the batch. If True we apply it to the first half of the batch
elements, otherwise to the second half.
layerdrop_indicator: a float that should be 1.0 if nothing is dropped out
and 0.0 if something is dropped out. You do not have to set this directly;
it is set internally when batch_split is not None.
Shape:
src: (S, N, E).
feature_mask: float, or (S, N, 1)
pos_emb: (N, 2*S-1, E)
src_mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
"""
if batch_split is not None:
process_first_half = batch_split
batch_size = src.shape[1]
mid = batch_size // 2
if attn_scores_in is None:
seq_len = src.shape[0]
num_heads = self.self_attn.num_heads
attn_scores_in = torch.zeros(1, 1, 1, 1, device=src.device, dtype=src.dtype).expand(
batch_size, seq_len, seq_len, num_heads)
attn_scores_a, attn_scores_b = attn_scores_in[:mid], attn_scores_in[mid:]
src_a, src_b = src[:, :mid], src[:, mid:]
key_padding_a, key_padding_b = src_key_padding_mask[:mid], src_key_padding_mask[mid:]
if process_first_half:
src_a, attn_scores_a = self.forward(src_a, pos_emb, attn_scores_a, src_mask,
key_padding_a, warmup, batch_split=None,
layerdrop_indicator=0.0)
else:
src_b, attn_scores_b = self.forward(src_b, pos_emb, attn_scores_b, src_mask,
key_padding_b, warmup, batch_split=None,
layerdrop_indicator=0.0)
return torch.cat((src_a, src_b), dim=1), torch.cat((attn_scores_a, attn_scores_b), dim=0)
src_orig = src
warmup_scale = min(0.1 + warmup, 1.0)
# alpha = 1.0 means fully use this encoder layer, 0.0 would mean
# completely bypass it.
if self.training:
alpha = (
warmup_scale
if torch.rand(()).item() <= (1.0 - self.layer_dropout)
else 0.1
)
else:
alpha = 1.0
alpha = warmup_scale if self.training else 1.0
# macaron style feed forward module
src = src + self.feed_forward_macaron_scale(self.feed_forward_macaron(src),
feature_mask)
layerdrop_indicator)
# multi-headed self-attention module
src_att, _, attn_scores_out = self.self_attn(
@@ -271,18 +295,18 @@ class ConformerEncoderLayer(nn.Module):
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)
src = src + self.self_attn_scale(src_att, feature_mask)
src = src + self.self_attn_scale(src_att, layerdrop_indicator)
# convolution module
src = src + self.conv_scale(self.conv_module(src, src_key_padding_mask=src_key_padding_mask),
feature_mask)
layerdrop_indicator)
# feed forward module
src = src + self.feed_forward_scale(self.feed_forward(src),
feature_mask)
layerdrop_indicator)
src = self.final_scale(src, feature_mask)
src = self.final_scale(src, layerdrop_indicator)
src = self.norm_final(self.balancer(src))
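The batch_split branch at the top of forward() above is what makes this layer dropout "efficient": instead of computing the layer for the whole batch and then discarding the result, the layer recurses on just one half of the batch dimension and passes the other half through untouched. A minimal standalone sketch of that pattern, with a generic nn.Module standing in for ConformerEncoderLayer (the helper name apply_to_half is illustrative, not from this commit):

    import torch
    import torch.nn as nn

    def apply_to_half(layer: nn.Module, src: torch.Tensor, first_half: bool) -> torch.Tensor:
        # src has shape (S, N, E); split along the batch dimension N (dim=1).
        mid = src.shape[1] // 2
        src_a, src_b = src[:, :mid], src[:, mid:]
        if first_half:
            src_a = layer(src_a)  # the layer is applied to the first half only
        else:
            src_b = layer(src_b)  # ... or to the second half only
        # the untouched half has effectively had this layer dropped
        return torch.cat((src_a, src_b), dim=1)

In the real layer the recursive call also threads through pos_emb, the attention scores, the masks and warmup, and passes layerdrop_indicator=0.0 so the scale modules can tell which calls belong to a dropped-out path.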
@@ -312,8 +336,12 @@ class ConformerEncoder(nn.Module):
encoder_layer: nn.Module,
num_layers: int,
aux_layers: List[int],
layer_dropout: float = 0.333
) -> None:
super().__init__()
assert 0 < layer_dropout < 0.5
self.layer_dropout = layer_dropout
self.layers = nn.ModuleList(
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
)
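The new assert keeps layer_dropout below 0.5; given the scheduling code further down, this guarantees that the number of dropped layer pairs, int(layer_dropout * 2 * num_layer_pairs), never exceeds the number of pairs. A quick check with the new default of 0.333 and 12 encoder layers:

    layer_dropout = 0.333
    num_layers = 12
    num_layer_pairs = num_layers // 2                             # 6
    num_dropped_pairs = int(layer_dropout * 2 * num_layer_pairs)  # int(3.996) == 3
    assert num_dropped_pairs <= num_layer_pairs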
@@ -359,34 +387,56 @@ class ConformerEncoder(nn.Module):
attn_scores = None
if self.training:
# deal with feature masking.
if not self.training:
feature_mask = 1.0
else:
# feature mask.
# on a fraction feature_mask_dropout_prob of the frames, drop out the
# extra features, to force a bottleneck.
feature_mask_dropout_prob = 0.15
feature_unmasked_dim = 256 # hardcoded for now; the first 256 dims are never masked.
full_feature_mask = torch.ones_like(src) # S, N, E
# feature_mask is 0 with probability `feature_mask_dropout_prob`
# feature_mask shape: (S, N, 1)
feature_mask = (torch.rand_like(src[...,:1]) > feature_mask_dropout_prob).to(src.dtype)
full_feature_mask[..., feature_unmasked_dim:] *= feature_mask
else:
feature_mask = 1.0
full_feature_mask = 1.0
feature_mask = torch.ones_like(src) # S, N, E
# frame_mask is 0 with probability `feature_mask_dropout_prob`
# frame_mask shape: (S, N, 1)
frame_mask = (torch.rand_like(src[...,:1]) > feature_mask_dropout_prob).to(src.dtype)
feature_mask[..., feature_unmasked_dim:] *= frame_mask
src = src * full_feature_mask
# deal with layer dropout.
batch_size = src.shape[1]
if not self.training or batch_size == 1:
dropped_layer_pairs = set() # empty set.
else:
num_layer_pairs = len(self.layers) // 2
layer_pairs = list(range(num_layer_pairs))
random.shuffle(layer_pairs)
# the * 2 is because only one layer from each dropped pair is skipped for
# any given batch element: one layer of the pair for one half of the batch
# and the other layer for the other half.
num_dropped_pairs = int(self.layer_dropout * 2 * num_layer_pairs)
dropped_layer_pairs = set(layer_pairs[:num_dropped_pairs])
rand_bool = (random.random() < 0.5)
src = src * feature_mask
for i, mod in enumerate(self.layers):
if i // 2 not in dropped_layer_pairs:
batch_split = None # no layer dropout
else:
batch_split = rand_bool if i % 2 == 0 else not rand_bool
output, attn_scores = mod(
output,
feature_mask,
pos_emb,
attn_scores,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
batch_split=batch_split,
)
output = output * full_feature_mask
output = output * feature_mask
if i in self.aux_layers:
outputs.append(output)
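To summarize the scheduling logic in ConformerEncoder.forward: for each dropped pair (2k, 2k+1), the even layer is applied to only one half of the batch and the odd layer to only the other half, so every batch element still passes through exactly one layer of the pair while one layer's worth of computation is saved per pair. A toy sketch of the per-layer batch_split values (the function name and return format here are illustrative, not part of the commit):

    import random
    from typing import List, Optional

    def layer_drop_schedule(num_layers: int, layer_dropout: float) -> List[Optional[bool]]:
        num_layer_pairs = num_layers // 2
        layer_pairs = list(range(num_layer_pairs))
        random.shuffle(layer_pairs)
        num_dropped_pairs = int(layer_dropout * 2 * num_layer_pairs)
        dropped_layer_pairs = set(layer_pairs[:num_dropped_pairs])
        rand_bool = random.random() < 0.5
        schedule = []
        for i in range(num_layers):
            if i // 2 not in dropped_layer_pairs:
                schedule.append(None)  # layer i runs on the whole batch
            else:
                # within a dropped pair, the even layer gets one batch half
                # and the odd layer gets the other
                schedule.append(rand_bool if i % 2 == 0 else not rand_bool)
        return schedule

    # e.g. with 12 layers, layer_dropout=0.333, pairs {1, 3, 4} selected and rand_bool=True:
    # [None, None, True, False, None, None, True, False, True, False, None, None]

Each batch element therefore skips num_dropped_pairs out of num_layers layers, giving a per-element drop rate of num_dropped_pairs / num_layers, which is approximately layer_dropout up to integer truncation.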