Implement efficient layer dropout

Daniel Povey 2022-10-03 17:19:16 +08:00
parent 93dff29243
commit b3af9f67ae

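The scheme, as implemented in the diff below: rather than stochastically zeroing a layer's output (which still pays for the layer's forward pass), a dropped layer is simply not run on half of the batch, and its partner layer in the same pair is skipped on the other half, so compute is actually saved while every layer still trains on half of each batch. A minimal sketch of the batch-splitting idea, assuming batch dimension 1 as in the code below (src has shape (S, N, E)); the helper and the toy nn.Linear layer are illustrative only, not part of the commit:

import torch
import torch.nn as nn


def apply_to_half(layer: nn.Module, src: torch.Tensor, first_half: bool) -> torch.Tensor:
    # Run `layer` on one half of the batch only; the other half passes through
    # unchanged and costs no compute.
    mid = src.shape[1] // 2
    src_a, src_b = src[:, :mid], src[:, mid:]
    if first_half:
        src_a = layer(src_a)  # second half is bypassed
    else:
        src_b = layer(src_b)  # first half is bypassed
    return torch.cat((src_a, src_b), dim=1)


layer = nn.Linear(8, 8)
x = torch.randn(5, 4, 8)  # (S, N, E)
y = apply_to_half(layer, x, first_half=True)
assert torch.equal(y[:, 2:], x[:, 2:])  # the bypassed half is unchanged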

@@ -60,7 +60,7 @@ class Conformer(EncoderInterface):
         dim_feedforward: int = 2048,
         num_encoder_layers: int = 12,
         dropout: float = 0.1,
-        layer_dropout: float = 0.075,
+        layer_dropout: float = 0.333,
         cnn_module_kernel: int = 31,
         aux_layer_period: int = 3,
     ) -> None:
@@ -85,13 +85,13 @@ class Conformer(EncoderInterface):
             nhead,
             dim_feedforward,
             dropout,
-            layer_dropout,
             cnn_module_kernel,
         )
         self.encoder = ConformerEncoder(
             encoder_layer,
             num_encoder_layers,
             aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
+            layer_dropout=layer_dropout,
         )
@@ -160,13 +160,10 @@ class ConformerEncoderLayer(nn.Module):
         nhead: int,
         dim_feedforward: int = 2048,
         dropout: float = 0.1,
-        layer_dropout: float = 0.075,
         cnn_module_kernel: int = 31,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.layer_dropout = layer_dropout
         self.d_model = d_model
         self.self_attn = RelPositionMultiheadAttention(
@@ -215,53 +212,80 @@ class ConformerEncoderLayer(nn.Module):
     def forward(
         self,
         src: Tensor,
-        feature_mask: Union[Tensor, float],
         pos_emb: Tensor,
         attn_scores_in: Optional[Tensor] = None,
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
         warmup: float = 1.0,
-    ) -> Tensor:
+        batch_split: Optional[bool] = None,
+        layerdrop_indicator: float = 1.0,
+    ) -> Tuple[Tensor, Tensor]:
         """
         Pass the input through the encoder layer.
         Args:
             src: the sequence to the encoder layer (required).
             pos_emb: Positional embedding tensor (required).
-            attn_scores_in: something with the dimension of attention weights (bsz * num_heads, len, len) that is
+            attn_scores_in: something with the dimension of attention weights (bsz, len, len, num_heads) that is
               passed from layer to layer.
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
-            feature_mask: a mask of shape (S, N, E), that randomly zeroes out
-              some of the features on each frame.
            warmup: controls selective bypass of layers; if < 1.0, we will
              bypass layers more frequently.
+            batch_split: if not None, this layer will only be applied to part of
+              the batch: if True, we apply it to the first half of the batch
+              elements, otherwise to the second half.
+            layerdrop_indicator: a float.  It is supposed to be 1.0 if nothing is
+              dropped out, and 0.0 if something is dropped out.  You don't have to
+              set this directly; it is set internally if you provide the
+              batch_split option as non-None.
        Shape:
            src: (S, N, E).
-            feature_mask: float, or (S, N, 1)
            pos_emb: (N, 2*S-1, E)
            src_mask: (S, S).
            src_key_padding_mask: (N, S).
            S is the source sequence length, N is the batch size, E is the feature number
        """
+        if batch_split is not None:
+            process_first_half = batch_split
+            batch_size = src.shape[1]
+            mid = batch_size // 2
+            if attn_scores_in is None:
+                seq_len = src.shape[0]
+                num_heads = self.self_attn.num_heads
+                attn_scores_in = torch.zeros(1, 1, 1, 1, device=src.device, dtype=src.dtype).expand(
+                    batch_size, seq_len, seq_len, num_heads)
+            attn_scores_a, attn_scores_b = attn_scores_in[:mid], attn_scores_in[mid:]
+            src_a, src_b = src[:, :mid], src[:, mid:]
+            key_padding_a, key_padding_b = src_key_padding_mask[:mid], src_key_padding_mask[mid:]
+            if process_first_half:
+                src_a, attn_scores_a = self.forward(src_a, pos_emb, attn_scores_a, src_mask,
+                                                    key_padding_a, warmup, batch_split=None,
+                                                    layerdrop_indicator=0.0)
+            else:
+                src_b, attn_scores_b = self.forward(src_b, pos_emb, attn_scores_b, src_mask,
+                                                    key_padding_b, warmup, batch_split=None,
+                                                    layerdrop_indicator=0.0)
+            return torch.cat((src_a, src_b), dim=1), torch.cat((attn_scores_a, attn_scores_b), dim=0)
         src_orig = src
         warmup_scale = min(0.1 + warmup, 1.0)
         # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
         # completely bypass it.
-        if self.training:
-            alpha = (
-                warmup_scale
-                if torch.rand(()).item() <= (1.0 - self.layer_dropout)
-                else 0.1
-            )
-        else:
-            alpha = 1.0
+        alpha = warmup_scale if self.training else 1.0
         # macaron style feed forward module
         src = src + self.feed_forward_macaron_scale(self.feed_forward_macaron(src),
-                                                    feature_mask)
+                                                    layerdrop_indicator)
         # multi-headed self-attention module
         src_att, _, attn_scores_out = self.self_attn(
@@ -271,18 +295,18 @@ class ConformerEncoderLayer(nn.Module):
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
         )
-        src = src + self.self_attn_scale(src_att, feature_mask)
+        src = src + self.self_attn_scale(src_att, layerdrop_indicator)
         # convolution module
         src = src + self.conv_scale(self.conv_module(src, src_key_padding_mask=src_key_padding_mask),
-                                    feature_mask)
+                                    layerdrop_indicator)
         # feed forward module
         src = src + self.feed_forward_scale(self.feed_forward(src),
-                                            feature_mask)
+                                            layerdrop_indicator)
-        src = self.final_scale(src, feature_mask)
+        src = self.final_scale(src, layerdrop_indicator)
         src = self.norm_final(self.balancer(src))
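A side note on the batch_split branch in the forward hunk above: when no attention scores have been passed in yet, the placeholder is built with torch.zeros(1, 1, 1, 1).expand(...), which is a broadcast view over a single zero element rather than a full (bsz, len, len, num_heads) allocation, and it can still be sliced and concatenated along the batch dimension. A small check, illustrative only and not part of the commit:

import torch

bsz, seq_len, num_heads = 4, 6, 2
scores = torch.zeros(1, 1, 1, 1).expand(bsz, seq_len, seq_len, num_heads)
assert scores.stride() == (0, 0, 0, 0)  # all-zero strides: a view, nothing copied
a, b = scores[:2], scores[2:]           # split along the batch dimension
assert torch.cat((a, b), dim=0).shape == (bsz, seq_len, seq_len, num_heads)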
@@ -312,8 +336,12 @@ class ConformerEncoder(nn.Module):
         encoder_layer: nn.Module,
         num_layers: int,
         aux_layers: List[int],
+        layer_dropout: float = 0.333,
     ) -> None:
         super().__init__()
+        assert 0 < layer_dropout < 0.5
+        self.layer_dropout = layer_dropout
         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
         )
@@ -359,34 +387,56 @@ class ConformerEncoder(nn.Module):
         attn_scores = None
-        if self.training:
+        # deal with feature masking.
+        if not self.training:
+            feature_mask = 1.0
+        else:
             # feature mask.
             # on 0.25 of the frames, drop out the extra features [force a bottleneck.]
             feature_mask_dropout_prob = 0.15
             feature_unmasked_dim = 256  # hardcode dim for now, 1st 256 are non-masked.
-            full_feature_mask = torch.ones_like(src)  # S, N, E
-            # feature_mask is 0 with probability `feature_mask_dropout_prob`
-            # feature_mask shape: (S, N, 1)
-            feature_mask = (torch.rand_like(src[..., :1]) > feature_mask_dropout_prob).to(src.dtype)
-            full_feature_mask[..., feature_unmasked_dim:] *= feature_mask
-        else:
-            feature_mask = 1.0
-            full_feature_mask = 1.0
-        src = src * full_feature_mask
+            feature_mask = torch.ones_like(src)  # S, N, E
+            # frame_mask is 0 with probability `feature_mask_dropout_prob`
+            # frame_mask shape: (S, N, 1)
+            frame_mask = (torch.rand_like(src[..., :1]) > feature_mask_dropout_prob).to(src.dtype)
+            feature_mask[..., feature_unmasked_dim:] *= frame_mask
+        # deal with layer dropout.
+        batch_size = src.shape[1]
+        if not self.training or batch_size == 1:
+            dropped_layer_pairs = set()  # empty set.
+        else:
+            num_layer_pairs = len(self.layers) // 2
+            layer_pairs = list(range(num_layer_pairs))
+            random.shuffle(layer_pairs)
+            # the * 2 is because we only drop out one layer from each pair:
+            # half for one half of the batch and the other half for the other.
+            num_dropped_pairs = int(self.layer_dropout * 2 * num_layer_pairs)
+            dropped_layer_pairs = set(layer_pairs[:num_dropped_pairs])
+            rand_bool = (random.random() < 0.5)
+        src = src * feature_mask
         for i, mod in enumerate(self.layers):
+            if i // 2 not in dropped_layer_pairs:
+                batch_split = None  # no layer dropout
+            else:
+                batch_split = rand_bool if i % 2 == 0 else not rand_bool
             output, attn_scores = mod(
                 output,
-                feature_mask,
                 pos_emb,
                 attn_scores,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
                 warmup=warmup,
+                batch_split=batch_split,
             )
-            output = output * full_feature_mask
+            output = output * feature_mask
             if i in self.aux_layers:
                 outputs.append(output)