Mirror of https://github.com/k2-fsa/icefall.git
Make the scaling factors more global and the randomness of dropout more random

Commit a9f950a1f7 (parent 96e0d92fb7)
@@ -60,7 +60,7 @@ class Conformer(EncoderInterface):
        dim_feedforward: int = 2048,
        num_encoder_layers: int = 12,
        dropout: float = 0.1,
        layer_dropout: float = 0.333,
        layer_dropout: float = 0.25,
        cnn_module_kernel: int = 31,
        aux_layer_period: int = 3,
    ) -> None:
@@ -153,7 +153,6 @@ class ConformerEncoderLayer(nn.Module):
        >>> pos_emb = torch.rand(32, 19, 512)
        >>> out = encoder_layer(src, pos_emb)
    """

    def __init__(
        self,
        d_model: int,
@@ -193,12 +192,8 @@ class ConformerEncoderLayer(nn.Module):
        self.conv_module = ConvolutionModule(d_model,
                                             cnn_module_kernel)

        self.norm_final = BasicNorm(d_model)

        # scale_alpha relates to a scale that can help work around layerdrop during training.
        self.scale_alpha = torch.nn.Parameter(torch.tensor(0.0))

        # try to ensure the output is close to zero-mean (or at least, zero-median).
        self.balancer = ActivationBalancer(
            d_model, channel_dim=-1,
@@ -207,7 +202,6 @@ class ConformerEncoderLayer(nn.Module):
            max_var_per_eig=0.2,
        )

    def forward(
        self,
        src: Tensor,
@@ -216,8 +210,8 @@ class ConformerEncoderLayer(nn.Module):
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        warmup: float = 1.0,
        batch_split: Optional[bool] = None,
        layerdrop_indicator: float = 1.0,
        layerdrop_mask: Optional[List[float]] = None,
        layerdrop_scales: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Pass the input through the encoder layer.
@@ -232,12 +226,12 @@ class ConformerEncoderLayer(nn.Module):
            warmup: controls selective bypass of layers; if < 1.0, we will
               bypass layers more frequently.
            batch_split: if not None, this layer will only be applied to
               part of the batch. If True we apply it to the first half of the batch
               elements, otherwise to the second half.
            layerdrop_indicator: a float. It is supposed to be 1.0 if nothing is dropped out,
               and 0.0 if something is dropped out. You don't have to set this directly;
               it is set internally if you provide the batch_split option as non-None.

            layerdrop_mask: if None or [1.0, 1.0] then we do the computation as normal. If
               [1.0, 0.0] or [0.0, 1.0], we will only do this computation for the first or
               second half of the batch respectively, and just copy the input for the other
               half.
            layerdrop_scales: an optional Tensor of shape (batch_size, 1) that will be used as a scale
               on the change in the embeddings made by this layer.

        Shape:
            src: (S, N, E).
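The layerdrop_mask convention above maps directly onto which half of the batch a layer processes. A minimal standalone sketch of that mapping (hypothetical helper name; it mirrors the check that appears in the forward() body in the next hunk, and is not part of the patch):

from typing import List, Optional

def which_half_to_process(layerdrop_mask: Optional[List[float]]) -> Optional[bool]:
    # True -> process the first half of the batch, False -> the second half,
    # None -> process the whole batch as normal.
    if layerdrop_mask is None or layerdrop_mask == [1.0, 1.0]:
        return None
    assert layerdrop_mask in ([1.0, 0.0], [0.0, 1.0])
    return layerdrop_mask == [1.0, 0.0]

print(which_half_to_process([1.0, 0.0]))  # True
print(which_half_to_process([0.0, 1.0]))  # False
print(which_half_to_process(None))        # None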
@@ -246,8 +240,9 @@ class ConformerEncoderLayer(nn.Module):
            src_key_padding_mask: (N, S).
            S is the source sequence length, N is the batch size, E is the feature number
        """
        if batch_split is not None:
            process_first_half = batch_split
        if layerdrop_mask not in [ None, [1.0, 1.0] ]:
            assert layerdrop_mask in [ [1.0, 0.0], [0.0, 1.0] ]
            process_first_half = (layerdrop_mask == [1.0, 0.0])
            batch_size = src.shape[1]
            mid = batch_size // 2
@@ -257,19 +252,21 @@ class ConformerEncoderLayer(nn.Module):
            attn_scores_in = torch.zeros(1, 1, 1, 1, device=src.device, dtype=src.dtype).expand(
                batch_size, seq_len, seq_len, num_heads)

            attn_scores_a, attn_scores_b = attn_scores_in[:mid], attn_scores_in[mid:]
            src_a, src_b = src[:, :mid], src[:, mid:]
            key_padding_a, key_padding_b = src_key_padding_mask[:mid], src_key_padding_mask[mid:],
            layerdrop_scales_a, layerdrop_scales_b = layerdrop_scales[:mid], layerdrop_scales[mid:]

            if process_first_half:
                src_a, attn_scores_a = self.forward(src_a, pos_emb, attn_scores_a, src_mask,
                                                    key_padding_a, warmup, batch_split=None,
                                                    layerdrop_indicator=0.0)
                                                    key_padding_a, warmup,
                                                    layerdrop_mask=None,
                                                    layerdrop_scales=layerdrop_scales_a)
            else:
                src_b, attn_scores_b = self.forward(src_b, pos_emb, attn_scores_b, src_mask,
                                                    key_padding_b, warmup, batch_split=None,
                                                    layerdrop_indicator=0.0)
                                                    key_padding_b, warmup,
                                                    layerdrop_mask=None,
                                                    layerdrop_scales=layerdrop_scales_b)

            return torch.cat((src_a, src_b), dim=1), torch.cat((attn_scores_a, attn_scores_b), dim=0)
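The branch above runs the layer on only one half of the batch, copies the other half through unchanged, and recombines with torch.cat; the zero-filled expand() for attn_scores_in creates a view, so no full (batch_size, seq_len, seq_len, num_heads) buffer is allocated. A simplified standalone sketch of the split-and-recombine pattern (hypothetical helper name; attention scores, masks and warmup are omitted):

import torch
import torch.nn as nn

def apply_to_half(module: nn.Module, src: torch.Tensor, first_half: bool) -> torch.Tensor:
    # src: (seq_len, batch, dim); process one half of the batch, copy the other.
    mid = src.shape[1] // 2
    src_a, src_b = src[:, :mid], src[:, mid:]
    if first_half:
        src_a = module(src_a)
    else:
        src_b = module(src_b)
    return torch.cat((src_a, src_b), dim=1)

layer = nn.Linear(16, 16)
x = torch.randn(10, 4, 16)
y = apply_to_half(layer, x, first_half=True)
# The second half of the batch is passed through unchanged.
assert torch.equal(y[:, 2:], x[:, 2:])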
@@ -304,11 +301,11 @@ class ConformerEncoderLayer(nn.Module):

        src = self.norm_final(self.balancer(src))

        if alpha != 1.0 or layerdrop_indicator != 1.0 or self.training:
        if alpha != 1.0 or layerdrop_scales is not None:
            # the if(self.training) part is to ensure we have a derivative for
            # self.scale_alpha.
            src_offset = src - src_orig
            scale = alpha * (1.0 + self.scale_alpha * (1.0 - layerdrop_indicator))
            scale = alpha * (1.0 if layerdrop_scales is None else layerdrop_scales)
            src = src_orig + src_offset * scale

        return src, attn_scores_out
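Both the old scale_alpha-based scale and the new layerdrop_scales are applied only to the change the layer made, not to the whole activation, so a scale of exactly 1.0 reproduces the normal output and 0.0 bypasses the layer. A minimal standalone sketch of that residual scaling (function name and shapes are illustrative):

import torch

def scale_layer_change(src_orig: torch.Tensor,
                       src: torch.Tensor,
                       scale: torch.Tensor) -> torch.Tensor:
    # Only the change the layer made is scaled; scale == 1.0 leaves the
    # output untouched, scale == 0.0 reduces the layer to the identity.
    return src_orig + (src - src_orig) * scale

src_orig = torch.randn(10, 4, 16)                # (S, N, E)
src = src_orig + torch.randn(10, 4, 16)          # pretend this is the layer output
scales = torch.full((4, 1), 1.1)                 # per-sequence scales, shape (batch_size, 1)
out = scale_layer_change(src_orig, src, scales)
assert torch.allclose(scale_layer_change(src_orig, src, torch.tensor(1.0)), src)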
@@ -334,7 +331,7 @@ class ConformerEncoder(nn.Module):
        encoder_layer: nn.Module,
        num_layers: int,
        aux_layers: List[int],
        layer_dropout: float = 0.333
        layer_dropout: float = 0.25
    ) -> None:
        super().__init__()
        assert 0 < layer_dropout < 0.5
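The assert keeps layer_dropout below 0.5, so fewer than half of the 2 * num_layers (layer, batch-half) slots can be dropped per batch. As a worked example with the new default of 0.25 and the default num_encoder_layers = 12 from the first hunk, the sampling loop in get_layerdrop_info() further down drops 6 of the 24 slots:

# Illustrative arithmetic only, using the defaults shown in this diff.
num_layers, layer_dropout = 12, 0.25
halves_to_drop = int(2 * num_layers * layer_dropout)
print(halves_to_drop)  # 6: each batch skips 6 of the 24 layer-halves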
@@ -345,6 +342,9 @@ class ConformerEncoder(nn.Module):
        )
        self.num_layers = num_layers

        self.layerdrop_scale_mat = nn.Parameter(0.01 * torch.randn(num_layers, num_layers))

        assert num_layers - 1 not in aux_layers
        self.aux_layers = set(aux_layers + [num_layers - 1])
@@ -355,6 +355,72 @@ class ConformerEncoder(nn.Module):
            random_prob=0.333,
        )

    def get_layerdrop_info(self,
                           batch_size: int) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Gets some random information that dictates the layer dropout configuration.
        Args:
           batch_size: the number of sequences in the batch
        Returns:
           (layerdrop_mask, layerdrop_scales)
        where:
           layerdrop_mask is a CPU tensor of shape (num_layers, 2), where the 2 represents
           the two halves of the batch, containing 1.0 for positions to be evaluated and 0.0
           for positions not to be evaluated. It has constraints: at least one of the two
           halves of each layer must be evaluated, and two successive layers must not both
           be skipped for the same half of the batch.

           layerdrop_scales is a learned Tensor of shape (num_layers, batch_size, 1) of the form:
              1.0 + [learned matrix * (1.0 - layerdrop_mask)]
           where layerdrop_mask is 1.0 for layers that were computed for this half, and
           0.0 for layers that were not computed. This is intended to learn that layers neighboring
           layers that were not computed should get a higher scale to "make up" for the missing
           computation.
           The reason for this specific functional form is that if everything
           is computed (layerdrop_mask is all 1.0), the scale is constrained to be exactly 1.0,
           to avoid introducing redundant degrees of freedom.
        """
        num_layers = self.num_layers

        layerdrop_mask = torch.ones(num_layers, 2, device='cpu')

        if not self.training or batch_size == 1:
            return layerdrop_mask, None

        halves_to_drop = int(2 * num_layers * self.layer_dropout)
        for _ in range(halves_to_drop):
            while True:
                r = random.randrange(0, 2 * num_layers)
                i = r // 2
                j = r % 2
                if layerdrop_mask[i, j - 1] == 0.0:
                    # This position cannot be set to 0.0 because the other
                    # half of the batch is already 0.0 (not computed). This would lead to
                    # one layer not having a gradient.
                    continue
                if ((i > 0 and layerdrop_mask[i-1, j] == 0.0) or
                    (i + 1 < num_layers and layerdrop_mask[i+1, j] == 0.0)):
                    # This position cannot be set to 0.0 because the preceding
                    # or following position for this same half of the batch is
                    # already set to 0.0.
                    continue
                layerdrop_mask[i, j] = 0.0
                break

        # layerdrop_scales_tmp: shape is (num_layers, 2)
        device = self.layerdrop_scale_mat.device
        layerdrop_scales_tmp = 1.0 + torch.matmul(self.layerdrop_scale_mat,
                                                  1.0 - layerdrop_mask.to(device))

        layerdrop_scales = torch.empty(num_layers, batch_size, 1, device=device)
        mid = batch_size // 2

        layerdrop_scales[:, :mid, 0] = layerdrop_scales_tmp[:, 0:1]  # shape: (num_layers, 1)
        layerdrop_scales[:, mid:, 0] = layerdrop_scales_tmp[:, 1:2]  # shape: (num_layers, 1)

        return layerdrop_mask, layerdrop_scales
    def forward(
        self,
        src: Tensor,
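A standalone sketch of the scale computation in get_layerdrop_info() above, with toy sizes: the learned (num_layers, num_layers) matrix turns the pattern of dropped layer-halves into a per-layer, per-half scale, and an all-ones mask collapses exactly to 1.0 (sizes and the dropped position below are illustrative):

import torch

num_layers, batch_size = 4, 6
scale_mat = 0.01 * torch.randn(num_layers, num_layers)   # analogous to layerdrop_scale_mat

layerdrop_mask = torch.ones(num_layers, 2)
layerdrop_mask[1, 0] = 0.0   # pretend layer 1 is skipped for the first half of the batch

# (num_layers, num_layers) @ (num_layers, 2) -> (num_layers, 2)
scales_tmp = 1.0 + scale_mat @ (1.0 - layerdrop_mask)

# Broadcast each column to its half of the batch: (num_layers, batch_size, 1).
mid = batch_size // 2
scales = torch.empty(num_layers, batch_size, 1)
scales[:, :mid, 0] = scales_tmp[:, 0:1]
scales[:, mid:, 0] = scales_tmp[:, 1:2]

# With nothing dropped, the formula collapses to exactly 1.0 everywhere.
all_ones = 1.0 + scale_mat @ (1.0 - torch.ones(num_layers, 2))
assert torch.all(all_ones == 1.0)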
@@ -401,30 +467,13 @@ class ConformerEncoder(nn.Module):
        feature_mask[..., feature_unmasked_dim:] *= frame_mask

        # deal with layer dropout.
        batch_size = src.shape[1]
        if not self.training or batch_size == 1:
            dropped_layer_pairs = set()  # empty set.
        else:
            num_layer_pairs = len(self.layers) // 2
            layer_pairs = list(range(num_layer_pairs))
            random.shuffle(layer_pairs)
            # the * 2 is because we only drop out one layer from each pair:
            # half for one half of the batch and the other half for the other.
            num_dropped_pairs = int(self.layer_dropout * 2 * num_layer_pairs)
            dropped_layer_pairs = set(layer_pairs[:num_dropped_pairs])
        layerdrop_mask, layerdrop_scales = self.get_layerdrop_info(batch_size=src.shape[1])

        rand_bool = (random.random() < 0.5)

        src = src * feature_mask

        for i, mod in enumerate(self.layers):

            if i // 2 not in dropped_layer_pairs:
                batch_split = None  # no layer dropout
            else:
                batch_split = rand_bool if i % 2 == 0 else not rand_bool

            output, attn_scores = mod(
                output,
                pos_emb,
@@ -432,7 +481,8 @@ class ConformerEncoder(nn.Module):
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                warmup=warmup,
                batch_split=batch_split,
                layerdrop_mask=layerdrop_mask[i].tolist(),  # [1.0, 1.0], [0.0, 1.0] or [1.0, 0.0]
                layerdrop_scales=layerdrop_scales[i],  # tensor of scales of shape (batch_size, 1)
            )
            output = output * feature_mask
            if i in self.aux_layers:
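Finally, a simplified standalone sketch of the new calling convention in ConformerEncoder.forward(): the mask and scales are drawn once per batch by get_layerdrop_info(), and layer i receives layerdrop_mask[i] and layerdrop_scales[i]. ToyLayer is a hypothetical stand-in that applies the scales to its residual but ignores the mask; the real layers also take pos_emb, attention scores, padding masks and warmup:

import torch
import torch.nn as nn

class ToyLayer(nn.Module):
    # Hypothetical stand-in for ConformerEncoderLayer.
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, src, layerdrop_mask=None, layerdrop_scales=None):
        src_orig = src
        src = src + self.proj(src)
        if layerdrop_scales is not None:
            # scale only the change this layer made, as in the hunks above
            src = src_orig + (src - src_orig) * layerdrop_scales
        return src

num_layers, seq_len, batch_size, dim = 4, 10, 6, 16
layers = nn.ModuleList(ToyLayer(dim) for _ in range(num_layers))

layerdrop_mask = torch.ones(num_layers, 2)
layerdrop_scales = torch.ones(num_layers, batch_size, 1)   # stand-in; during training these come from get_layerdrop_info()

output = torch.randn(seq_len, batch_size, dim)
for i, mod in enumerate(layers):
    output = mod(output,
                 layerdrop_mask=layerdrop_mask[i].tolist(),  # [1.0, 1.0], [1.0, 0.0] or [0.0, 1.0]
                 layerdrop_scales=layerdrop_scales[i])       # shape (batch_size, 1)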