Mirror of https://github.com/k2-fsa/icefall.git
Implement efficient layer dropout

commit b3af9f67ae (parent 93dff29243)
@@ -60,7 +60,7 @@ class Conformer(EncoderInterface):
         dim_feedforward: int = 2048,
         num_encoder_layers: int = 12,
         dropout: float = 0.1,
-        layer_dropout: float = 0.075,
+        layer_dropout: float = 0.333,
         cnn_module_kernel: int = 31,
         aux_layer_period: int = 3,
     ) -> None:
@@ -85,13 +85,13 @@ class Conformer(EncoderInterface):
             nhead,
             dim_feedforward,
             dropout,
-            layer_dropout,
             cnn_module_kernel,
         )
         self.encoder = ConformerEncoder(
             encoder_layer,
             num_encoder_layers,
             aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
+            layer_dropout=layer_dropout,
         )

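For a concrete reading of the aux_layers expression above: with the defaults in this file (num_encoder_layers=12, aux_layer_period=3), it selects every third layer below the last one. A quick illustrative check, not part of the diff:

# Illustrative only: evaluates the aux_layers expression with the default
# values used above (num_encoder_layers=12, aux_layer_period=3).
num_encoder_layers = 12
aux_layer_period = 3
aux_layers = list(range(0, num_encoder_layers - 1, aux_layer_period))
print(aux_layers)  # [0, 3, 6, 9]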
@@ -160,13 +160,10 @@ class ConformerEncoderLayer(nn.Module):
         nhead: int,
         dim_feedforward: int = 2048,
         dropout: float = 0.1,
-        layer_dropout: float = 0.075,
         cnn_module_kernel: int = 31,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()

-        self.layer_dropout = layer_dropout
-
         self.d_model = d_model

         self.self_attn = RelPositionMultiheadAttention(
@@ -215,53 +212,80 @@ class ConformerEncoderLayer(nn.Module):
     def forward(
         self,
         src: Tensor,
-        feature_mask: Union[Tensor, float],
         pos_emb: Tensor,
         attn_scores_in: Optional[Tensor] = None,
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
         warmup: float = 1.0,
-    ) -> Tensor:
+        batch_split: Optional[bool] = None,
+        layerdrop_indicator: float = 1.0,
+    ) -> Tuple[Tensor, Tensor]:
         """
         Pass the input through the encoder layer.

         Args:
             src: the sequence to the encoder layer (required).
             pos_emb: Positional embedding tensor (required).
-            attn_scores_in: something with the dimension of attention weights (bsz * num_heads, len, len) that is
+            attn_scores_in: something with the dimension of attention weights (bsz, len, len, num_heads) that is
               passed from layer to layer.
             src_mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
-            feature_mask: a mask of shape (S, N, E), that randomly zeroes out
-              some of the features on each frame.
             warmup: controls selective bypass of layers; if < 1.0, we will
              bypass layers more frequently.
+            batch_split: if not None, this layer will only be applied to
+              part of the batch: if True we apply it to the first half of the batch
+              elements, otherwise to the second half.
+            layerdrop_indicator: a float. It is supposed to be 1.0 if nothing is dropped out,
+              and 0.0 if something is dropped out. You don't have to set this directly;
+              it is set internally if you provide the batch_split option as non-None.

         Shape:
             src: (S, N, E).
-            feature_mask: float, or (S, N, 1)
             pos_emb: (N, 2*S-1, E)
             src_mask: (S, S).
             src_key_padding_mask: (N, S).
             S is the source sequence length, N is the batch size, E is the feature number
         """
+        if batch_split is not None:
+            process_first_half = batch_split
+            batch_size = src.shape[1]
+            mid = batch_size // 2
+
+            if attn_scores_in is None:
+                seq_len = src.shape[0]
+                num_heads = self.self_attn.num_heads
+                attn_scores_in = torch.zeros(1, 1, 1, 1, device=src.device, dtype=src.dtype).expand(
+                    batch_size, seq_len, seq_len, num_heads)
+
+            attn_scores_a, attn_scores_b = attn_scores_in[:mid], attn_scores_in[mid:]
+            src_a, src_b = src[:, :mid], src[:, mid:]
+            key_padding_a, key_padding_b = src_key_padding_mask[:mid], src_key_padding_mask[mid:]
+
+            if process_first_half:
+                src_a, attn_scores_a = self.forward(src_a, pos_emb, attn_scores_a, src_mask,
+                                                    key_padding_a, warmup, batch_split=None,
+                                                    layerdrop_indicator=0.0)
+            else:
+                src_b, attn_scores_b = self.forward(src_b, pos_emb, attn_scores_b, src_mask,
+                                                    key_padding_b, warmup, batch_split=None,
+                                                    layerdrop_indicator=0.0)
+
+            return torch.cat((src_a, src_b), dim=1), torch.cat((attn_scores_a, attn_scores_b), dim=0)
+
         src_orig = src

         warmup_scale = min(0.1 + warmup, 1.0)
         # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
         # completely bypass it.
-        if self.training:
-            alpha = (
-                warmup_scale
-                if torch.rand(()).item() <= (1.0 - self.layer_dropout)
-                else 0.1
-            )
-        else:
-            alpha = 1.0
+        alpha = warmup_scale if self.training else 1.0

         # macaron style feed forward module
         src = src + self.feed_forward_macaron_scale(self.feed_forward_macaron(src),
-                                                    feature_mask)
+                                                    layerdrop_indicator)

         # multi-headed self-attention module
         src_att, _, attn_scores_out = self.self_attn(
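The batch_split / layerdrop_indicator behaviour documented in the new docstring can be illustrated with a small standalone sketch; the helper name and the toy layer below are ours, not icefall code. When batch_split is set, the layer runs on only one half of the batch and the other half passes through unchanged, so the output keeps the full batch shape:

import torch

def apply_with_batch_split(layer, src, batch_split):
    # src: (S, N, E).  batch_split=None: run `layer` on the whole batch.
    # batch_split=True: run it on the first half of the batch only.
    # batch_split=False: run it on the second half only.
    # The untouched half is passed through unchanged; in the diff above the
    # processed half is computed by a recursive self.forward call with
    # layerdrop_indicator=0.0.
    if batch_split is None:
        return layer(src)
    mid = src.shape[1] // 2
    src_a, src_b = src[:, :mid], src[:, mid:]
    if batch_split:
        src_a = layer(src_a)
    else:
        src_b = layer(src_b)
    return torch.cat((src_a, src_b), dim=1)

# Toy usage on a (S=4, N=6, E=8) input, with a linear layer standing in
# for a ConformerEncoderLayer.
layer = torch.nn.Linear(8, 8)
src = torch.randn(4, 6, 8)
out = apply_with_batch_split(layer, src, batch_split=True)
print(out.shape)  # torch.Size([4, 6, 8])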
@@ -271,18 +295,18 @@ class ConformerEncoderLayer(nn.Module):
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
         )
-        src = src + self.self_attn_scale(src_att, feature_mask)
+        src = src + self.self_attn_scale(src_att, layerdrop_indicator)

         # convolution module
         src = src + self.conv_scale(self.conv_module(src, src_key_padding_mask=src_key_padding_mask),
-                                    feature_mask)
+                                    layerdrop_indicator)


         # feed forward module
         src = src + self.feed_forward_scale(self.feed_forward(src),
-                                            feature_mask)
+                                            layerdrop_indicator)

-        src = self.final_scale(src, feature_mask)
+        src = self.final_scale(src, layerdrop_indicator)

         src = self.norm_final(self.balancer(src))
@@ -312,8 +336,12 @@ class ConformerEncoder(nn.Module):
         encoder_layer: nn.Module,
         num_layers: int,
         aux_layers: List[int],
+        layer_dropout: float = 0.333
     ) -> None:
         super().__init__()
+        assert 0 < layer_dropout < 0.5
+        self.layer_dropout = layer_dropout
+
         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
         )
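A rough sense of the new assert and the 0.333 default, based on the pair-based selection in the forward() hunk below: keeping layer_dropout below 0.5 keeps the number of dropped pairs strictly below the number of pairs, and (up to integer truncation) the expected fraction of layer evaluations skipped per example comes out close to layer_dropout. A small worked check with assumed values:

# Worked check with assumed values (layer_dropout=0.333, num_layers=12);
# mirrors the arithmetic used in ConformerEncoder.forward below.
layer_dropout = 0.333
num_layers = 12
num_layer_pairs = num_layers // 2                             # 6
num_dropped_pairs = int(layer_dropout * 2 * num_layer_pairs)  # int(3.996) == 3
# Each dropped pair skips one of its two layers for every example (which
# layer depends on which half of the batch the example is in), so the
# expected fraction of skipped layer evaluations per example is:
frac_skipped = (num_dropped_pairs / num_layer_pairs) * 0.5    # 0.25 here
print(num_layer_pairs, num_dropped_pairs, frac_skipped)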
@@ -359,34 +387,56 @@ class ConformerEncoder(nn.Module):
         attn_scores = None


-        if self.training:
+        # deal with feature masking.
+        if not self.training:
+            feature_mask = 1.0
+        else:
             # feature mask.
             # on 0.25 of the frames, drop out the extra features [force a bottleneck.]
             feature_mask_dropout_prob = 0.15
             feature_unmasked_dim = 256  # hardcode dim for now, 1st 256 are non-masked.

-            full_feature_mask = torch.ones_like(src)  # S, N, E
-            # feature_mask is 0 with probability `feature_mask_dropout_prob`
-            # feature_mask shape: (S, N, 1)
-            feature_mask = (torch.rand_like(src[...,:1]) > feature_mask_dropout_prob).to(src.dtype)
-            full_feature_mask[..., feature_unmasked_dim:] *= feature_mask
-        else:
-            feature_mask = 1.0
-            full_feature_mask = 1.0
+            feature_mask = torch.ones_like(src)  # S, N, E
+            # frame_mask is 0 with probability `feature_mask_dropout_prob`
+            # frame_mask shape: (S, N, 1)
+            frame_mask = (torch.rand_like(src[...,:1]) > feature_mask_dropout_prob).to(src.dtype)
+            feature_mask[..., feature_unmasked_dim:] *= frame_mask

-        src = src * full_feature_mask
+        # deal with layer dropout.
+        batch_size = src.shape[1]
+        if not self.training or batch_size == 1:
+            dropped_layer_pairs = set()  # empty set.
+        else:
+            num_layer_pairs = len(self.layers) // 2
+            layer_pairs = list(range(num_layer_pairs))
+            random.shuffle(layer_pairs)
+            # the * 2 is because we only drop out one layer from each pair:
+            # half for one half of the batch and the other half for the other.
+            num_dropped_pairs = int(self.layer_dropout * 2 * num_layer_pairs)
+            dropped_layer_pairs = set(layer_pairs[:num_dropped_pairs])
+
+        rand_bool = (random.random() < 0.5)
+
+        src = src * feature_mask

         for i, mod in enumerate(self.layers):
+            if i // 2 not in dropped_layer_pairs:
+                batch_split = None  # no layer dropout
+            else:
+                batch_split = rand_bool if i % 2 == 0 else not rand_bool
+
             output, attn_scores = mod(
                 output,
-                feature_mask,
                 pos_emb,
                 attn_scores,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
                 warmup=warmup,
+                batch_split=batch_split,
             )
-            output = output * full_feature_mask
+            output = output * feature_mask
             if i in self.aux_layers:
                 outputs.append(output)
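For reference, the pair-based selection in the hunk above can be mimicked as a standalone sketch (the function name is ours, not icefall code): shuffle the layer pairs, mark a fraction of them as dropped, and inside each dropped pair send one layer to the first half of the batch and the other to the second half, so both halves skip the same number of layers.

import random

def plan_layer_dropout(num_layers, layer_dropout, training=True, batch_size=2):
    # Returns one entry per layer: None means "run on the full batch";
    # True/False means "run only on the first / second half of the batch".
    # This mirrors the selection logic in ConformerEncoder.forward above,
    # but the function itself is only an illustrative sketch.
    if not training or batch_size == 1:
        return [None] * num_layers
    num_layer_pairs = num_layers // 2
    layer_pairs = list(range(num_layer_pairs))
    random.shuffle(layer_pairs)
    num_dropped_pairs = int(layer_dropout * 2 * num_layer_pairs)
    dropped_layer_pairs = set(layer_pairs[:num_dropped_pairs])
    rand_bool = random.random() < 0.5
    plan = []
    for i in range(num_layers):
        if i // 2 not in dropped_layer_pairs:
            plan.append(None)  # this layer sees the whole batch
        else:
            plan.append(rand_bool if i % 2 == 0 else not rand_bool)
    return plan

random.seed(0)
print(plan_layer_dropout(num_layers=12, layer_dropout=0.333))
# 12 entries: 3 pairs get complementary True/False flags, the other 6 layers get None.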