Remove final combination; implement layer drop that drops the final layers.

Daniel Povey 2022-10-04 22:19:44 +08:00
parent 006fcc18cd
commit 5fe8cb134f
2 changed files with 61 additions and 112 deletions
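For orientation before the diff: the old scheme dropped layers independently for the two halves of the batch and recombined intermediate outputs with AttentionCombine; the new scheme builds a single per-layer keep/drop mask whose main job is to drop a block of final layers. Below is a minimal standalone sketch of that mask generation, not the commit's code: the helper name make_layerdrop_mask and the standalone seeding are illustrative, the probabilities (0.1 keep-all, 0.25 drop-tail) mirror the diff, and the real get_layerdrop_info() additionally applies a small per-layer drop probability and derives per-layer scales from the mask.

import random

import torch


def make_layerdrop_mask(num_layers: int, seed: int) -> torch.Tensor:
    # 1.0 = evaluate the layer, 0.0 = skip it.  Seeding from a shared counter
    # keeps all DDP workers choosing the same layers for a given batch.
    r = random.Random(seed)
    mask = torch.ones(num_layers)
    p = r.random()
    if p < 0.1:
        # Sometimes keep the complete network.
        return mask
    if p < 0.1 + 0.25:
        # With prob 0.25, drop the last n layers, n a multiple of 3
        # (what aux_layer_period=3 used to give us).
        n = 3 * r.randint(1, num_layers // 3)
        mask[-n:] = 0.0
    # Never drop every layer.
    if mask.sum() == 0.0:
        mask[0] = 1.0
    return mask


print(make_layerdrop_mask(num_layers=12, seed=0))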


@@ -62,7 +62,6 @@ class Conformer(EncoderInterface):
         dropout: float = 0.1,
         layer_dropout: float = 0.25,
         cnn_module_kernel: int = 31,
-        aux_layer_period: int = 3,
     ) -> None:
         super(Conformer, self).__init__()
@@ -90,7 +89,6 @@ class Conformer(EncoderInterface):
         self.encoder = ConformerEncoder(
             encoder_layer,
             num_encoder_layers,
-            aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
             layer_dropout=layer_dropout,
         )
@@ -210,8 +208,7 @@ class ConformerEncoderLayer(nn.Module):
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
         warmup: float = 1.0,
-        layerdrop_mask: Optional[List[float]] = None,
-        layerdrop_scales: Optional[Tensor] = None,
+        layerdrop_scale: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor]:
         """
         Pass the input through the encoder layer.
@@ -226,11 +223,7 @@ class ConformerEncoderLayer(nn.Module):
             warmup: controls selective bypass of layers; if < 1.0, we will
               bypass layers more frequently.
             batch_split: if not None, this layer will only be applied to
-            layerdrop_mask: if None or [1.0, 1.0] then we do the computation as normal.  If
-              [1.0, 0.0] or [0.0, 1.0], we will only do this computation for the first or
-              second half of the batch respectively, and just copy the input for the other
-              half.
-            layerdrop_scales: an optional Tensor of shape (batch_size, 1) that will be used as a scale
+            layerdrop_scale: an optional Tensor broadcasting with `src` that will be used as a scale
              on the change in the embeddings made by this layer.

         Shape:
@@ -240,40 +233,8 @@ class ConformerEncoderLayer(nn.Module):
             src_key_padding_mask: (N, S).
             S is the source sequence length, N is the batch size, E is the feature number
         """
-        if layerdrop_mask not in [ None, [1.0, 1.0] ]:
-            assert layerdrop_mask in [ [1.0, 0.0], [0.0, 1.0] ]
-            process_first_half = (layerdrop_mask == [1.0, 0.0])
-            batch_size = src.shape[1]
-            mid = batch_size // 2
-            if attn_scores_in is None:
-                seq_len = src.shape[0]
-                num_heads = self.self_attn.num_heads
-                attn_scores_in = torch.zeros(1, 1, 1, 1, device=src.device, dtype=src.dtype).expand(
-                    batch_size, seq_len, seq_len, num_heads)
-            attn_scores_a, attn_scores_b = attn_scores_in[:mid], attn_scores_in[mid:]
-            src_a, src_b = src[:, :mid], src[:, mid:]
-            key_padding_a, key_padding_b = src_key_padding_mask[:mid], src_key_padding_mask[mid:],
-            layerdrop_scales_a, layerdrop_scales_b = layerdrop_scales[:mid], layerdrop_scales[mid:]
-            if process_first_half:
-                src_a, attn_scores_a = self.forward(src_a, pos_emb, attn_scores_a, src_mask,
-                                                    key_padding_a, warmup,
-                                                    layerdrop_mask=None,
-                                                    layerdrop_scales=layerdrop_scales_a)
-            else:
-                src_b, attn_scores_b = self.forward(src_b, pos_emb, attn_scores_b, src_mask,
-                                                    key_padding_b, warmup,
-                                                    layerdrop_mask=None,
-                                                    layerdrop_scales=layerdrop_scales_b)
-            return torch.cat((src_a, src_b), dim=1), torch.cat((attn_scores_a, attn_scores_b), dim=0)
-
         src_orig = src
         warmup_scale = min(0.1 + warmup, 1.0)
         # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
         # completely bypass it.
@@ -301,11 +262,11 @@ class ConformerEncoderLayer(nn.Module):
         src = self.norm_final(self.balancer(src))

-        if alpha != 1.0 or layerdrop_scales is not None:
+        if alpha != 1.0 or layerdrop_scale is not None:
             # the if(self.training) part is to ensure we have a derivative for
             # self.scale_alpha.
             src_offset = src - src_orig
-            scale = alpha * (1.0 if layerdrop_scales is None else layerdrop_scales)
+            scale = alpha * layerdrop_scale
             src = src_orig + src_offset * scale

         return src, attn_scores_out
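With the batch-splitting path removed, the layer's contribution above reduces to a single scaled residual update. A toy illustration of that arithmetic, not this commit's code (the shapes and the 0.9 value are invented for the example; in the real code layerdrop_scale is supplied by the encoder):

import torch

# Toy shapes in the (S, N, E) layout ConformerEncoderLayer uses.
S, N, E = 5, 2, 8
src_orig = torch.randn(S, N, E)        # input to the layer
src = src_orig + torch.randn(S, N, E)  # output after the layer's modules
alpha = 1.0                            # warmup bypass factor from the existing code
layerdrop_scale = torch.tensor(0.9)    # per-layer scale; 0-dim, so it broadcasts with src

scale = alpha * layerdrop_scale
src = src_orig + (src - src_orig) * scale  # scale=1.0 keeps the full update,
                                           # scale=0.0 bypasses the layer
print(src.shape)  # torch.Size([5, 2, 8])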
@@ -325,16 +286,18 @@ class ConformerEncoder(nn.Module):
         >>> pos_emb = torch.rand(32, 19, 512)
         >>> out = conformer_encoder(src, pos_emb)
     """

     def __init__(
         self,
         encoder_layer: nn.Module,
         num_layers: int,
-        aux_layers: List[int],
         layer_dropout: float = 0.25
     ) -> None:
         super().__init__()
         assert 0 < layer_dropout < 0.5
+        # `count` tracks how many times the forward function has been called
+        # since we initialized the model (it is not written to disk or read when
+        # we resume training).  It is used for random seeding for layer dropping.
+        self.count = 0
         self.layer_dropout = layer_dropout
         self.layers = nn.ModuleList(
@@ -342,31 +305,25 @@ class ConformerEncoder(nn.Module):
         )
         self.num_layers = num_layers

-        self.layerdrop_scale_mat = nn.Parameter(0.01 * torch.randn(num_layers, num_layers))
-        self.layerdrop_scale_offset = nn.Parameter(torch.ones(num_layers))
+        self.to_layerdrop_scales = nn.Sequential(
+            ScaledLinear(num_layers, 256, initial_scale=0.5),
+            nn.ReLU(),
+            ScaledLinear(256, num_layers, initial_scale=0.01))

-        assert num_layers - 1 not in aux_layers
-        self.aux_layers = set(aux_layers + [num_layers - 1])

         num_channels = encoder_layer.norm_final.num_channels
-        self.combiner = AttentionCombine(
-            num_channels=encoder_layer.d_model,
-            num_inputs=len(self.aux_layers),
-            random_prob=0.333,
-        )

-    def get_layerdrop_info(self,
-                           batch_size: int) -> Tuple[Tensor, Optional[Tensor]]:
+    def get_layerdrop_info(self) -> Tuple[Tensor, Optional[Tensor]]:
         """
         Gets some random information that dictates layer dropout configuration.

         Args:
            batch_size: the number of sequences in the batch
         Returns:
-          (layerdrop_mask, layerdrop_scales)
+          (mask, layerdrop_scales)
          where:
-          layerdrop_mask is a CPU tensor of shape (num_layers, 2) where the 2 represents
-          two halves of the batch, containing 1.0 for positions to be evaluated and 0.0
-          for positions not to be evaluated.  It has constraints: at least one of two
+          layerdrop_mask is a CPU tensor of shape (num_layers,),
+          containing 1.0 for layers to be evaluated and 0.0
+          for layers not to be evaluated.  It has constraints: at least one of two
           halves of each layer must be evaluated, and successive layers of the same half
           must be evaluated.
@@ -382,42 +339,39 @@ class ConformerEncoder(nn.Module):
         """
         num_layers = self.num_layers

-        layerdrop_mask = torch.ones(num_layers, 2, device='cpu')
-
-        if self.training and batch_size != 1:
-            halves_to_drop = int(2 * num_layers * self.layer_dropout)
-            for _ in range(halves_to_drop):
-                while True:
-                    r = random.randrange(0, 2 * num_layers)
-                    i = r // 2
-                    j = r % 2
-                    if layerdrop_mask[i, j - 1] == 0.0:
-                        # This position cannot be set to 0.0 because the other
-                        # half of the batch is already 0.0 (not computed). This would lead to
-                        # one layer not having a gradient.
-                        continue
-                    if ((i > 0 and layerdrop_mask[i-1, j] == 0.0) or
-                        (i + 1 < num_layers and layerdrop_mask[i+1, j] == 0.0)):
-                        # This position cannot be set to False because the preceding
-                        # or following position for this same half of the batch is
-                        # already set to False
-                        continue
-                    layerdrop_mask[i, j] = 0.0
-                    break
-
-        # layerdrop_scales: currently shape is (2, num_layers)
-        device = self.layerdrop_scale_mat.device
-        layerdrop_scales_tmp = (self.layerdrop_scale_offset.unsqueeze(1) +
-                                torch.matmul(self.layerdrop_scale_mat,
-                                             1.0 - layerdrop_mask.to(device)))
-        layerdrop_scales = torch.empty(num_layers, batch_size, 1, device=device)
-        mid = batch_size // 2
-        layerdrop_scales[:, :mid, 0] = layerdrop_scales_tmp[:,0:1]  # shape: (num_layers, 1)
-        layerdrop_scales[:, mid:, 0] = layerdrop_scales_tmp[:,1:2]  # shape: (num_layers, 1)
-        return layerdrop_mask, layerdrop_scales
+        # This ensures that if we are using multiple worker processes, they all use the same
+        # random numbers, so they will all take about the same amount of time to process
+        # the batch.
+        r = random.Random(self.count)
+        self.count += 1
+
+        def get_random_mask():
+            # 1.0 means don't drop the layer, 0.0 means drop the layer
+            mask = torch.ones(num_layers, device='cpu')
+            if not self.training:
+                return mask
+            p = r.random()
+            if p < 0.1:
+                # drop zero layers, to make sure that sometimes we see the complete network.
+                return mask
+            final_layers_dropped = 0
+            if p < 0.1 + 0.25:
+                # with prob 0.25: completely drop the last n layers.  let n
+                # be a multiple of 3 (this is what we used to do with aux_layers).
+                final_layers_dropped = 3 * r.randint(1, num_layers // 3)
+                mask[-final_layers_dropped:] = 0.0
+
+            layer_drop_prob = 0.075
+            for i in range(final_layers_dropped):
+                mask[i] = (r.random() > layer_drop_prob)
+
+            if mask.sum() == 0.0:
+                mask[0] = 1.0
+            return mask
+
+        mask = get_random_mask()
+        device = self.to_layerdrop_scales[0].weight.device
+        layerdrop_scales = 1.0 + self.to_layerdrop_scales(mask.to(device))
+        return mask, layerdrop_scales

     def forward(
@@ -466,30 +420,24 @@ class ConformerEncoder(nn.Module):
             feature_mask[..., feature_unmasked_dim:] *= frame_mask

         # deal with layer dropout.
-        layerdrop_mask, layerdrop_scales = self.get_layerdrop_info(batch_size=src.shape[1])
+        layerdrop_mask, layerdrop_scales = self.get_layerdrop_info()

-        src = src * feature_mask
+        output = output * feature_mask

         for i, mod in enumerate(self.layers):
-            output, attn_scores = mod(
-                output,
-                pos_emb,
-                attn_scores,
-                src_mask=mask,
-                src_key_padding_mask=src_key_padding_mask,
-                warmup=warmup,
-                layerdrop_mask=layerdrop_mask[i].tolist(),  # [ 1.0, 1.0 ], [0.0, 1.0] or [1.0, 0.0]
-                layerdrop_scales=layerdrop_scales[i],  # tensor of scales of shape (batch_size, 1)
-            )
+            if layerdrop_mask[i] != 0.0:
+                output, attn_scores = mod(
+                    output,
+                    pos_emb,
+                    attn_scores,
+                    src_mask=mask,
+                    src_key_padding_mask=src_key_padding_mask,
+                    warmup=warmup,
+                    layerdrop_scale=layerdrop_scales[i],
+                )
             output = output * feature_mask

-            if i in self.aux_layers:
-                outputs.append(output)
-
-        output = self.combiner(outputs)
+        output = output + 0.0 * attn_scores.sum()  # just ensure attn_scores is used in backprop

         return output
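The layerdrop_scales indexed in the loop above come from the new to_layerdrop_scales module: it maps the 0/1 mask over layers to a per-layer correction, and the 1.0 + offset together with the tiny initial_scale on the final ScaledLinear keeps the scales close to 1.0 early in training. A rough sketch of the same idea, using plain nn.Linear with a hand-applied 0.01 scaling because ScaledLinear and its initial_scale argument are specific to icefall:

import torch
import torch.nn as nn

num_layers = 12
# Stand-in for the commit's ScaledLinear stack.
to_layerdrop_scales = nn.Sequential(
    nn.Linear(num_layers, 256),
    nn.ReLU(),
    nn.Linear(256, num_layers),
)
with torch.no_grad():
    # Imitate ScaledLinear(..., initial_scale=0.01): the output starts near zero,
    # so 1.0 + output starts near 1.0.
    to_layerdrop_scales[-1].weight *= 0.01
    to_layerdrop_scales[-1].bias.zero_()

mask = torch.ones(num_layers)
mask[-3:] = 0.0                                # e.g. the last 3 layers dropped
layerdrop_scales = 1.0 + to_layerdrop_scales(mask)
print(layerdrop_scales.shape)                  # torch.Size([12]), one scale per layer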


@@ -924,7 +924,8 @@ def run(rank, world_size, args):
     model.to(device)
     if world_size > 1:
         logging.info("Using DDP")
-        model = DDP(model, device_ids=[rank])
+        model = DDP(model, device_ids=[rank],
+                    find_unused_parameters=True)

     optimizer = ScaledAdam(model.parameters(),
                            lr=params.initial_lr,
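The find_unused_parameters=True added above is what makes DDP compatible with the new layer dropping: when a whole layer is skipped for a batch, its parameters produce no gradients that step, and DDP's default reducer reports that as an error. A minimal sketch of the wrapping (the helper name wrap_for_ddp is illustrative, and process-group initialization is assumed to have happened already, as run() does):

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_for_ddp(model: nn.Module, rank: int) -> DDP:
    # Assumes torch.distributed.init_process_group() has already been called.
    model = model.to(torch.device("cuda", rank))
    # With random layer dropping, a skipped layer's parameters receive no gradient
    # for that batch; find_unused_parameters=True tells DDP's reducer to detect and
    # tolerate such unused parameters instead of raising an error.
    return DDP(model, device_ids=[rank], find_unused_parameters=True)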