Remove final combination; implement layer drop that drops the final layers.

Daniel Povey 2022-10-04 22:19:44 +08:00
parent 006fcc18cd
commit 5fe8cb134f
2 changed files with 61 additions and 112 deletions

View File

@ -62,7 +62,6 @@ class Conformer(EncoderInterface):
dropout: float = 0.1,
layer_dropout: float = 0.25,
cnn_module_kernel: int = 31,
aux_layer_period: int = 3,
) -> None:
super(Conformer, self).__init__()
@ -90,7 +89,6 @@ class Conformer(EncoderInterface):
self.encoder = ConformerEncoder(
encoder_layer,
num_encoder_layers,
aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
layer_dropout=layer_dropout,
)
@ -210,8 +208,7 @@ class ConformerEncoderLayer(nn.Module):
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0,
layerdrop_mask: Optional[List[float]] = None,
layerdrop_scales: Optional[Tensor] = None,
layerdrop_scale: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
"""
Pass the input through the encoder layer.
@ -226,11 +223,7 @@ class ConformerEncoderLayer(nn.Module):
warmup: controls selective bypass of layers; if < 1.0, we will
bypass layers more frequently.
batch_split: if not None, this layer will only be applied to
layerdrop_mask: if None or [1.0, 1.0] then we do the computation as normal. If
[1.0, 0.0] or [0.0, 1.0], we will only do this computation for the first or
second half of the batch respectively, and just copy the input for the other
half.
layerdrop_scales: an optional Tensor of shape (batch_size, 1) that will be used as a scale
layerdrop_scale: an optional Tensor, broadcastable with `src`, that will be used as a scale
on the change in the embeddings made by this layer.
Shape:
@ -240,40 +233,8 @@ class ConformerEncoderLayer(nn.Module):
src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
"""
if layerdrop_mask not in [ None, [1.0, 1.0] ]:
assert layerdrop_mask in [ [1.0, 0.0], [0.0, 1.0] ]
process_first_half = (layerdrop_mask == [1.0, 0.0])
batch_size = src.shape[1]
mid = batch_size // 2
if attn_scores_in is None:
seq_len = src.shape[0]
num_heads = self.self_attn.num_heads
attn_scores_in = torch.zeros(1, 1, 1, 1, device=src.device, dtype=src.dtype).expand(
batch_size, seq_len, seq_len, num_heads)
attn_scores_a, attn_scores_b = attn_scores_in[:mid], attn_scores_in[mid:]
src_a, src_b = src[:, :mid], src[:, mid:]
key_padding_a, key_padding_b = src_key_padding_mask[:mid], src_key_padding_mask[mid:],
layerdrop_scales_a, layerdrop_scales_b = layerdrop_scales[:mid], layerdrop_scales[mid:]
if process_first_half:
src_a, attn_scores_a = self.forward(src_a, pos_emb, attn_scores_a, src_mask,
key_padding_a, warmup,
layerdrop_mask=None,
layerdrop_scales=layerdrop_scales_a)
else:
src_b, attn_scores_b = self.forward(src_b, pos_emb, attn_scores_b, src_mask,
key_padding_b, warmup,
layerdrop_mask=None,
layerdrop_scales=layerdrop_scales_b)
return torch.cat((src_a, src_b), dim=1), torch.cat((attn_scores_a, attn_scores_b), dim=0)
src_orig = src
warmup_scale = min(0.1 + warmup, 1.0)
# alpha = 1.0 means fully use this encoder layer, 0.0 would mean
# completely bypass it.
@ -301,11 +262,11 @@ class ConformerEncoderLayer(nn.Module):
src = self.norm_final(self.balancer(src))
if alpha != 1.0 or layerdrop_scales is not None:
if alpha != 1.0 or layerdrop_scale is not None:
# the if(self.training) part is to ensure we have a derivative for
# self.scale_alpha.
src_offset = src - src_orig
scale = alpha * (1.0 if layerdrop_scales is None else layerdrop_scales)
scale = alpha * layerdrop_scale
src = src_orig + src_offset * scale
return src, attn_scores_out
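For readers skimming the diff: only the change a layer makes to its input is scaled, so a scale near zero effectively bypasses the layer while keeping everything differentiable. A minimal standalone sketch (shapes and values are illustrative, not code from this repo):

import torch

seq_len, batch, d_model = 10, 4, 8
src_orig = torch.randn(seq_len, batch, d_model)   # input to the encoder layer
src = torch.randn(seq_len, batch, d_model)        # output after norm_final
alpha = 0.7                                       # warmup-dependent bypass factor
layerdrop_scale = torch.tensor(1.05)              # learned scale, broadcastable with src

# Only the offset the layer adds is scaled; a scale of 0 reduces to the identity.
src_offset = src - src_orig
out = src_orig + src_offset * (alpha * layerdrop_scale)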
@ -325,16 +286,18 @@ class ConformerEncoder(nn.Module):
>>> pos_emb = torch.rand(32, 19, 512)
>>> out = conformer_encoder(src, pos_emb)
"""
def __init__(
self,
encoder_layer: nn.Module,
num_layers: int,
aux_layers: List[int],
layer_dropout: float = 0.25
) -> None:
super().__init__()
assert 0 < layer_dropout < 0.5
# `count` tracks how many times the forward function has been called
# since we initialized the model (it is not written to disk or read when
# we resume training). It is used for random seeding for layer dropping.
self.count = 0
self.layer_dropout = layer_dropout
self.layers = nn.ModuleList(
@ -342,31 +305,25 @@ class ConformerEncoder(nn.Module):
)
self.num_layers = num_layers
self.layerdrop_scale_mat = nn.Parameter(0.01 * torch.randn(num_layers, num_layers))
self.layerdrop_scale_offset = nn.Parameter(torch.ones(num_layers))
self.to_layerdrop_scales = nn.Sequential(
ScaledLinear(num_layers, 256, initial_scale=0.5),
nn.ReLU(),
ScaledLinear(256, num_layers, initial_scale=0.01))
assert num_layers - 1 not in aux_layers
self.aux_layers = set(aux_layers + [num_layers - 1])
num_channels = encoder_layer.norm_final.num_channels
self.combiner = AttentionCombine(
num_channels=encoder_layer.d_model,
num_inputs=len(self.aux_layers),
random_prob=0.333,
)
def get_layerdrop_info(self,
batch_size: int) -> Tuple[Tensor, Optional[Tensor]]:
def get_layerdrop_info(self) -> Tuple[Tensor, Optional[Tensor]]:
"""
Gets some random information that dictates layer dropout configuration.
Args:
batch_size: the number of sequences in the batch
Returns:
(layerdrop_mask, layerdrop_scales)
(mask, layerdrop_scales)
where:
layerdrop_mask is a CPU tensor of shape (num_layers, 2) where the 2 represents
two halves of the batch, containing 1.0 for positions to be evaluated and 0.0
for positions not to be evaluated. It has constraints: at least one of two
layerdrop_mask is a CPU tensor of shape (num_layers,),
containing 1.0 for layers to be evaluated and 0.0
for layers not to be evaluated. It is constrained so that at least
one layer is always evaluated.
@ -382,42 +339,39 @@ class ConformerEncoder(nn.Module):
"""
num_layers = self.num_layers
layerdrop_mask = torch.ones(num_layers, 2, device='cpu')
# This ensures that if we are using multiple worker processes, they all use the same
# random numbers, so they will all take about the same amount of time to process
# the batch.
r = random.Random(self.count)
self.count += 1
if self.training and batch_size != 1:
halves_to_drop = int(2 * num_layers * self.layer_dropout)
for _ in range(halves_to_drop):
while True:
r = random.randrange(0, 2 * num_layers)
i = r // 2
j = r % 2
if layerdrop_mask[i, j - 1] == 0.0:
# This position cannot be set to 0.0 because the other
# half of the batch is already 0.0 (not computed). This would lead to
# one layer not having a gradient.
continue
if ((i > 0 and layerdrop_mask[i-1, j] == 0.0) or
(i + 1 < num_layers and layerdrop_mask[i+1, j] == 0.0)):
# This position cannot be set to False because the preceding
# or following position for this same half of the batch is
# already set to False
continue
layerdrop_mask[i, j] = 0.0
break
def get_random_mask():
# 1.0 means don't drop the layer, 0.0 means drop the layer
mask = torch.ones(num_layers, device='cpu')
if not self.training:
return mask
p = r.random()
if p < 0.1:
# drop zero layers, to make sure that sometimes we see the complete network.
return mask
final_layers_dropped = 0
if p < 0.1 + 0.25:
# with prob 0.25: completely drop the last n layers. let n
# be a multiple of 3 (this is what we used to do with aux_layers).
final_layers_dropped = 3 * r.randint(1, num_layers // 3)
mask[-final_layers_dropped:] = 0.0
# layerdrop_scales: currently shape is (2, num_layers)
device = self.layerdrop_scale_mat.device
layerdrop_scales_tmp = (self.layerdrop_scale_offset.unsqueeze(1) +
torch.matmul(self.layerdrop_scale_mat,
1.0 - layerdrop_mask.to(device)))
layer_drop_prob = 0.075
for i in range(num_layers - final_layers_dropped):
mask[i] = (r.random() > layer_drop_prob)
layerdrop_scales = torch.empty(num_layers, batch_size, 1, device=device)
mid = batch_size // 2
if mask.sum() == 0.0:
mask[0] = 1.0
return mask
mask = get_random_mask()
device = self.to_layerdrop_scales[0].weight.device
layerdrop_scales = 1.0 + self.to_layerdrop_scales(mask.to(device))
return mask, layerdrop_scales
layerdrop_scales[:, :mid, 0] = layerdrop_scales_tmp[:,0:1] # shape: (num_layers, 1)
layerdrop_scales[:, mid:, 0] = layerdrop_scales_tmp[:,1:2] # shape: (num_layers, 1)
return layerdrop_mask, layerdrop_scales
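To make the new sampling scheme concrete, here is a self-contained sketch of how the mask and per-layer scales could be produced. It assumes the reading above in which the per-layer dropout applies to the layers preceding the dropped tail, and it uses plain nn.Linear in place of icefall's ScaledLinear; the free function and its names are illustrative, not the repo's API:

import random
import torch
import torch.nn as nn

def sample_layerdrop(num_layers, count, to_scales, training=True):
    # Seed from a shared counter so all DDP workers draw the same mask and
    # therefore spend about the same time on the batch.
    r = random.Random(count)
    mask = torch.ones(num_layers)
    if training:
        p = r.random()
        if p >= 0.1:                          # with prob 0.1, keep the full network
            final_layers_dropped = 0
            if p < 0.1 + 0.25:                # with prob 0.25, drop a final block of layers
                final_layers_dropped = 3 * r.randint(1, num_layers // 3)
                mask[-final_layers_dropped:] = 0.0
            layer_drop_prob = 0.075           # small independent dropout on the remaining layers
            for i in range(num_layers - final_layers_dropped):
                mask[i] = float(r.random() > layer_drop_prob)
            if mask.sum() == 0.0:
                mask[0] = 1.0                 # never drop every layer
    # A small MLP maps the drop pattern to per-layer scales on each layer's output change.
    scales = 1.0 + to_scales(mask)
    return mask, scales

num_layers = 12
to_scales = nn.Sequential(nn.Linear(num_layers, 256), nn.ReLU(),
                          nn.Linear(256, num_layers))
mask, scales = sample_layerdrop(num_layers, count=0, to_scales=to_scales)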
def forward(
@ -466,13 +420,14 @@ class ConformerEncoder(nn.Module):
feature_mask[..., feature_unmasked_dim:] *= frame_mask
# deal with layer dropout.
layerdrop_mask, layerdrop_scales = self.get_layerdrop_info(batch_size=src.shape[1])
layerdrop_mask, layerdrop_scales = self.get_layerdrop_info()
src = src * feature_mask
output = output * feature_mask
for i, mod in enumerate(self.layers):
if layerdrop_mask[i] != 0.0:
output, attn_scores = mod(
output,
pos_emb,
@ -480,16 +435,9 @@ class ConformerEncoder(nn.Module):
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
layerdrop_mask=layerdrop_mask[i].tolist(), # [ 1.0, 1.0 ], [0.0, 1.0] or [1.0, 0.0]
layerdrop_scales=layerdrop_scales[i], # tensor of scales of shape (batch_size, 1)
layerdrop_scale=layerdrop_scales[i],
)
output = output * feature_mask
if i in self.aux_layers:
outputs.append(output)
output = self.combiner(outputs)
output = output + 0.0 * attn_scores.sum() # just ensure attn_scores is used in backprop
return output
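The 0.0 * attn_scores.sum() term above is a standard trick for keeping a tensor attached to the autograd graph without changing the output; a tiny standalone illustration (not from this repo):

import torch

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)

out = x * 2.0
out = out + 0.0 * y.sum()   # numerically a no-op, but y now participates in backprop
out.sum().backward()
print(y.grad)               # a tensor of zeros rather than None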

View File

@ -924,7 +924,8 @@ def run(rank, world_size, args):
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank])
model = DDP(model, device_ids=[rank],
find_unused_parameters=True)
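# find_unused_parameters=True is needed once layers can be dropped at random:
# a skipped layer's parameters receive no gradient on that step, and DDP's
# gradient reducer does not tolerate unused parameters unless this flag is set.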
optimizer = ScaledAdam(model.parameters(),
lr=params.initial_lr,