Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Remove final combination; implement layer drop that drops the final layers.
This commit is contained in:
parent 006fcc18cd · commit 5fe8cb134f
@@ -62,7 +62,6 @@ class Conformer(EncoderInterface):
         dropout: float = 0.1,
         layer_dropout: float = 0.25,
         cnn_module_kernel: int = 31,
-        aux_layer_period: int = 3,
     ) -> None:
         super(Conformer, self).__init__()

@@ -90,7 +89,6 @@ class Conformer(EncoderInterface):
         self.encoder = ConformerEncoder(
             encoder_layer,
             num_encoder_layers,
-            aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
             layer_dropout=layer_dropout,
         )

@@ -210,8 +208,7 @@ class ConformerEncoderLayer(nn.Module):
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
         warmup: float = 1.0,
-        layerdrop_mask: Optional[List[float]] = None,
-        layerdrop_scales: Optional[Tensor] = None,
+        layerdrop_scale: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor]:
         """
         Pass the input through the encoder layer.
@@ -226,11 +223,7 @@ class ConformerEncoderLayer(nn.Module):
           warmup: controls selective bypass of of layers; if < 1.0, we will
             bypass layers more frequently.
           batch_split: if not None, this layer will only be applied to
-          layerdrop_mask: if None or [1.0, 1.0] then we do the computation as normal. If
-            [1.0, 0.0] or [0.0, 1.0], we will only do this computation for the first or
-            second half of the batch respectively, and just copy the input for the other
-            half.
-          layerdrop_scales: an optional Tensor of shape (batch_size, 1) that will be used as a scale
+          layerdrop_scale: an optional Tensor broadcasting with `src` that will be used as a scale
             on the change in the embeddings made by this layer.

         Shape:
@@ -240,40 +233,8 @@ class ConformerEncoderLayer(nn.Module):
             src_key_padding_mask: (N, S).
             S is the source sequence length, N is the batch size, E is the feature number
         """
-        if layerdrop_mask not in [ None, [1.0, 1.0] ]:
-            assert layerdrop_mask in [ [1.0, 0.0], [0.0, 1.0] ]
-            process_first_half = (layerdrop_mask == [1.0, 0.0])
-            batch_size = src.shape[1]
-            mid = batch_size // 2
-
-            if attn_scores_in is None:
-                seq_len = src.shape[0]
-                num_heads = self.self_attn.num_heads
-                attn_scores_in = torch.zeros(1, 1, 1, 1, device=src.device, dtype=src.dtype).expand(
-                    batch_size, seq_len, seq_len, num_heads)
-
-            attn_scores_a, attn_scores_b = attn_scores_in[:mid], attn_scores_in[mid:]
-            src_a, src_b = src[:, :mid], src[:, mid:]
-            key_padding_a, key_padding_b = src_key_padding_mask[:mid], src_key_padding_mask[mid:],
-            layerdrop_scales_a, layerdrop_scales_b = layerdrop_scales[:mid], layerdrop_scales[mid:]
-
-            if process_first_half:
-                src_a, attn_scores_a = self.forward(src_a, pos_emb, attn_scores_a, src_mask,
-                                                    key_padding_a, warmup,
-                                                    layerdrop_mask=None,
-                                                    layerdrop_scales=layerdrop_scales_a)
-            else:
-                src_b, attn_scores_b = self.forward(src_b, pos_emb, attn_scores_b, src_mask,
-                                                    key_padding_b, warmup,
-                                                    layerdrop_mask=None,
-                                                    layerdrop_scales=layerdrop_scales_b)
-
-            return torch.cat((src_a, src_b), dim=1), torch.cat((attn_scores_a, attn_scores_b), dim=0)
-
         src_orig = src

         warmup_scale = min(0.1 + warmup, 1.0)
         # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
         # completely bypass it.
@@ -301,11 +262,11 @@ class ConformerEncoderLayer(nn.Module):

         src = self.norm_final(self.balancer(src))

-        if alpha != 1.0 or layerdrop_scales is not None:
+        if alpha != 1.0 or layerdrop_scale is not None:
             # the if(self.training) part is to ensure we have a derivative for
             # self.scale_alpha.
             src_offset = src - src_orig
-            scale = alpha * (1.0 if layerdrop_scales is None else layerdrop_scales)
+            scale = alpha * layerdrop_scale
             src = src_orig + src_offset * scale

         return src, attn_scores_out
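The scaling above is the heart of the new scheme: instead of combining auxiliary outputs at the end, each layer contributes a residual `src - src_orig` that is rescaled by `alpha` (warmup-controlled bypass) and by the learned `layerdrop_scale`. A minimal sketch of that arithmetic, with `layer_fn` as a hypothetical stand-in for the attention/convolution/feed-forward stack (not the icefall module); note the sketch guards the `layerdrop_scale is None` case, which the diff's `alpha * layerdrop_scale` line does not:

import torch

def scaled_residual_step(src, layer_fn, alpha=1.0, layerdrop_scale=None):
    # Apply the layer, then keep only a scaled fraction of the change it made.
    src_orig = src
    src = layer_fn(src)
    if alpha != 1.0 or layerdrop_scale is not None:
        scale = alpha if layerdrop_scale is None else alpha * layerdrop_scale
        src = src_orig + (src - src_orig) * scale
    return src

x = torch.randn(10, 4, 8)                      # (seq_len, batch, d_model)
layer = torch.nn.Linear(8, 8)                  # toy stand-in for a Conformer layer
y = scaled_residual_step(x, layer, alpha=0.5)  # keeps half of the layer's change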
@@ -325,16 +286,18 @@ class ConformerEncoder(nn.Module):
         >>> pos_emb = torch.rand(32, 19, 512)
         >>> out = conformer_encoder(src, pos_emb)
     """

     def __init__(
         self,
         encoder_layer: nn.Module,
         num_layers: int,
-        aux_layers: List[int],
         layer_dropout: float = 0.25
     ) -> None:
         super().__init__()
         assert 0 < layer_dropout < 0.5
+        # `count` tracks how many times the forward function has been called
+        # since we initialized the model (it is not written to disk or read when
+        # we resume training). It is used for random seeding for layer dropping.
+        self.count = 0
         self.layer_dropout = layer_dropout

         self.layers = nn.ModuleList(
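The `count`-seeded randomness is what keeps layer dropping synchronized across DDP workers: every process increments the same counter and seeds `random.Random` with it, so all workers draw the same mask and do roughly the same amount of work per batch. A minimal sketch of that idea (hypothetical helper using a simplified per-layer Bernoulli drop, not the exact scheme in get_layerdrop_info below):

import random
import torch

def draw_mask(count, num_layers, layer_dropout=0.25):
    # Seed from the shared call counter so every worker draws the same mask.
    r = random.Random(count)
    mask = torch.ones(num_layers)
    for i in range(num_layers):
        if r.random() < layer_dropout:
            mask[i] = 0.0
    return mask

# Two "workers" at the same training step draw identical masks.
assert torch.equal(draw_mask(7, 12), draw_mask(7, 12))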
@@ -342,31 +305,25 @@ class ConformerEncoder(nn.Module):
         )
         self.num_layers = num_layers

-        self.layerdrop_scale_mat = nn.Parameter(0.01 * torch.randn(num_layers, num_layers))
-        self.layerdrop_scale_offset = nn.Parameter(torch.ones(num_layers))
+        self.to_layerdrop_scales = nn.Sequential(
+            ScaledLinear(num_layers, 256, initial_scale=0.5),
+            nn.ReLU(),
+            ScaledLinear(256, num_layers, initial_scale=0.01))

-        assert num_layers - 1 not in aux_layers
-        self.aux_layers = set(aux_layers + [num_layers - 1])

         num_channels = encoder_layer.norm_final.num_channels
-        self.combiner = AttentionCombine(
-            num_channels=encoder_layer.d_model,
-            num_inputs=len(self.aux_layers),
-            random_prob=0.333,
-        )

-    def get_layerdrop_info(self,
-                           batch_size: int) -> Tuple[Tensor, Optional[Tensor]]:
+    def get_layerdrop_info(self) -> Tuple[Tensor, Optional[Tensor]]:
         """
         Gets some random information that dictates layer dropout configuration.
         Args:
           batch_size: the number of sequences in the batch
         Returns:
-          (layerdrop_mask, layerdrop_scales)
+          (mask, layerdrop_scales)
         where:
-          layerdrop_mask is a CPU tensor of shape (num_layers, 2) where the 2 represents
-          two halves of the batch, containing 1.0 for positions to be evaluated and 0.0
-          for positions not to be evaluated. It has constraints: at least one of two
+          layerdrop_mask is a CPU tensor of shape (num_layers,),
+          containing 1.0 for layers to be evaluated and 0.0
+          for layers not to be evaluated. It has constraints: at least one of two
           halves of each layer must be evaluated, and successive layers of the same half
           pmust be evaluated.
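The new `to_layerdrop_scales` module replaces the old `layerdrop_scale_mat`/`layerdrop_scale_offset` parameters: the drop mask itself is fed through a small MLP and the result, offset by 1.0, becomes a per-layer scale, so the network can learn to compensate for which layers were dropped. `ScaledLinear` is icefall's scaled-initialization linear layer; the sketch below substitutes plain `nn.Linear` with the same shapes just to show the data flow (in icefall, `initial_scale=0.01` on the second layer keeps the scales close to 1.0 at initialization):

import torch
import torch.nn as nn

num_layers = 12
to_layerdrop_scales = nn.Sequential(
    nn.Linear(num_layers, 256),   # ScaledLinear(num_layers, 256, initial_scale=0.5) in icefall
    nn.ReLU(),
    nn.Linear(256, num_layers),   # ScaledLinear(256, num_layers, initial_scale=0.01) in icefall
)

mask = torch.ones(num_layers)
mask[-3:] = 0.0                           # last three layers dropped this batch
scales = 1.0 + to_layerdrop_scales(mask)  # one learned scale per layer
print(scales.shape)                       # torch.Size([12])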
@@ -382,42 +339,39 @@ class ConformerEncoder(nn.Module):
         """
         num_layers = self.num_layers

-        layerdrop_mask = torch.ones(num_layers, 2, device='cpu')
-
-        if self.training and batch_size != 1:
-            halves_to_drop = int(2 * num_layers * self.layer_dropout)
-            for _ in range(halves_to_drop):
-                while True:
-                    r = random.randrange(0, 2 * num_layers)
-                    i = r // 2
-                    j = r % 2
-                    if layerdrop_mask[i, j - 1] == 0.0:
-                        # This position cannot be set to 0.0 because the other
-                        # half of the batch is already 0.0 (not computed). This would lead to
-                        # one layer not having a gradient.
-                        continue
-                    if ((i > 0 and layerdrop_mask[i-1, j] == 0.0) or
-                        (i + 1 < num_layers and layerdrop_mask[i+1, j] == 0.0)):
-                        # This position cannot be set to False because the preceding
-                        # or following position for this same half of the batch is
-                        # already set to False
-                        continue
-                    layerdrop_mask[i, j] = 0.0
-                    break
-
-        # layerdrop_scales: currently shape is (2, num_layers)
-        device = self.layerdrop_scale_mat.device
-        layerdrop_scales_tmp = (self.layerdrop_scale_offset.unsqueeze(1) +
-                                torch.matmul(self.layerdrop_scale_mat,
-                                             1.0 - layerdrop_mask.to(device)))
-
-        layerdrop_scales = torch.empty(num_layers, batch_size, 1, device=device)
-        mid = batch_size // 2
-
-        layerdrop_scales[:, :mid, 0] = layerdrop_scales_tmp[:,0:1] # shape: (num_layers, 1)
-        layerdrop_scales[:, mid:, 0] = layerdrop_scales_tmp[:,1:2] # shape: (num_layers, 1)
-
-        return layerdrop_mask, layerdrop_scales
+        # This ensures that if we are using multiple worker processes, they all use the same
+        # random numbers, so they will all take about the same amount of time to process
+        # the batch.
+        r = random.Random(self.count)
+        self.count += 1
+
+        def get_random_mask():
+            # 1.0 means don't drop the layer, 0.0 means drop the layer
+            mask = torch.ones(num_layers, device='cpu')
+            if self.training:
+                return mask
+            r = r.random()
+            if r < 0.1:
+                # drop zero layers, to make sure that sometimes we see the complete network.
+                return mask
+            final_layers_dropped = 0
+            if r < 0.1 + 0.25:
+                # with prob 0.25: completely drop the last n layers. let n
+                # be a multiple of 3 (this is what we used to do with aux_layers).
+                final_layers_dropped = 3 * r.randint(1, num_layers // 3)
+                mask[-final_layers_dropped:] = 0.0
+
+            layer_drop_prob = 0.075
+            for i in range(final_layers_dropped):
+                mask[i] = (r.random() > layer_drop_prob)
+
+            if mask.sum() == 0.0:
+                mask[0] = 1.0
+
+        mask = get_random_mask()
+        device = self.to_layerdrop_scales[0].weight.device
+        layerdrop_scales = 1.0 + self.to_layerdrop_scales(mask.to(device))
+        return mask, layerdrop_scales


     def forward(
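The added code drops layers per batch rather than per batch-half: roughly a quarter of training batches drop a final block of layers (a multiple of 3, mirroring the old aux_layer_period), plus a sparse 7.5% per-layer drop, and the drawn mask is then turned into learned per-layer scales. As extracted, the helper reassigns `r = r.random()` (shadowing the `random.Random` object) and does not always return the mask, so it would not run as written. The sketch below keeps the same structure but fixes those Python-level issues; the `not training` early return and the final `return mask` are my assumptions about the intent, not lines confirmed by the diff:

import random
import torch

def get_layerdrop_mask(count, num_layers, training=True):
    # Seed from a shared call counter so all DDP workers draw the same mask.
    r = random.Random(count)
    # 1.0 means keep the layer, 0.0 means drop it.
    mask = torch.ones(num_layers)
    if not training:                      # assumption: no dropping at eval time
        return mask
    x = r.random()                        # separate name; the diff reuses `r` here
    if x < 0.1:
        # Drop nothing, so the complete network is sometimes seen.
        return mask
    final_layers_dropped = 0
    if x < 0.1 + 0.25:
        # With prob 0.25: completely drop the last n layers, n a multiple of 3.
        final_layers_dropped = 3 * r.randint(1, num_layers // 3)
        mask[-final_layers_dropped:] = 0.0
    layer_drop_prob = 0.075
    for i in range(final_layers_dropped):
        # Additionally drop a few of the earliest layers with small probability
        # (loop bound kept as in the extracted diff).
        mask[i] = 1.0 if r.random() > layer_drop_prob else 0.0
    if mask.sum() == 0.0:
        mask[0] = 1.0                     # never drop every layer
    return mask                           # assumption: the diff lacks this line

print(get_layerdrop_mask(count=3, num_layers=12))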
@@ -466,30 +420,24 @@ class ConformerEncoder(nn.Module):
         feature_mask[..., feature_unmasked_dim:] *= frame_mask

         # deal with layer dropout.
-        layerdrop_mask, layerdrop_scales = self.get_layerdrop_info(batch_size=src.shape[1])
+        layerdrop_mask, layerdrop_scales = self.get_layerdrop_info()

-        src = src * feature_mask
+        output = output * feature_mask

         for i, mod in enumerate(self.layers):
-            output, attn_scores = mod(
-                output,
-                pos_emb,
-                attn_scores,
-                src_mask=mask,
-                src_key_padding_mask=src_key_padding_mask,
-                warmup=warmup,
-                layerdrop_mask=layerdrop_mask[i].tolist(), # [ 1.0, 1.0 ], [0.0, 1.0] or [1.0, 0.0]
-                layerdrop_scales=layerdrop_scales[i], # tensor of scales of shape (batch_size, 1)
-            )
+            if layerdrop_mask[i] != 0.0:
+                output, attn_scores = mod(
+                    output,
+                    pos_emb,
+                    attn_scores,
+                    src_mask=mask,
+                    src_key_padding_mask=src_key_padding_mask,
+                    warmup=warmup,
+                    layerdrop_scale=layerdrop_scales[i],
+                )
             output = output * feature_mask
-            if i in self.aux_layers:
-                outputs.append(output)
-
-        output = self.combiner(outputs)
-
-        output = output + 0.0 * attn_scores.sum() # just ensure attn_scores is used in backprop

         return output
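With the per-batch mask, skipping a layer simply means not calling it, rather than slicing the batch in half as before. A toy version of that loop, using hypothetical stand-in modules (a plain residual `output + mod(output)` stands in for the encoder layer's internal residual connections, and the feature-mask multiply is omitted):

import torch
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
mask = torch.tensor([1.0, 1.0, 0.0, 1.0])   # third layer dropped this batch

output = torch.randn(10, 2, 8)              # (seq_len, batch, d_model)
for i, mod in enumerate(layers):
    if mask[i] != 0.0:
        # Residual form, so a skipped layer leaves the stream unchanged.
        output = output + mod(output)
print(output.shape)                         # torch.Size([10, 2, 8])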
@@ -924,7 +924,8 @@ def run(rank, world_size, args):
     model.to(device)
     if world_size > 1:
         logging.info("Using DDP")
-        model = DDP(model, device_ids=[rank])
+        model = DDP(model, device_ids=[rank],
+                    find_unused_parameters=True)

     optimizer = ScaledAdam(model.parameters(),
                            lr=params.initial_lr,
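`find_unused_parameters=True` is needed because, once whole layers can be skipped for a batch, their parameters receive no gradient that step, and DDP's default reducer would otherwise error out waiting for them (the diff also removes the old `0.0 * attn_scores.sum()` trick that kept attention scores in the graph). A quick single-process way to see which parameters go unused when a layer is dropped (illustrative sketch, no DDP required):

import torch
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
x = torch.randn(4, 8)
out = x
for i, mod in enumerate(layers):
    if i != 1:                  # pretend layer 1 was dropped this batch
        out = out + mod(out)
out.sum().backward()

for i, mod in enumerate(layers):
    print(i, mod.weight.grad is None)   # True only for the dropped layer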