Have warmup that gradually removes dropout from layers; multiply initialization scales by 0.1.
parent 300da1306d
commit fe4a7e904f

@@ -64,7 +64,7 @@ class Conformer(EncoderInterface):
         num_encoder_layers: Tuple[int] = (12, 12),
         dropout: float = 0.1,
         cnn_module_kernel: Tuple[int] = (31, 31),
-        warmup_batches: float = 6000.0,
+        warmup_batches: float = 4000.0,
     ) -> None:
         super(Conformer, self).__init__()
 
@@ -96,11 +96,14 @@ class Conformer(EncoderInterface):
             cnn_module_kernel[0],
 
         )
+        # for the first third of the warmup period, we let the Conv2dSubsampling
+        # layer learn something
         self.encoder1 = ConformerEncoder(
             encoder_layer1,
             num_encoder_layers[0],
             dropout,
-            warmup_batches,
+            warmup_begin=warmup_batches / 3,
+            warmup_end=2 * warmup_batches / 3,
         )
         encoder_layer2 = ConformerEncoderLayer(
             d_model[1],
@@ -108,13 +111,15 @@ class ConformerEncoderLayer(nn.Module):
             feedforward_dim[1],
             dropout,
             cnn_module_kernel[1],
 
         )
         self.encoder2 = DownsampledConformerEncoder(
             ConformerEncoder(
                 encoder_layer2,
                 num_encoder_layers[1],
                 dropout,
-                warmup_batches,
+                warmup_begin=2 * warmup_batches / 3,
+                warmup_end=warmup_batches,
             ),
             input_dim=d_model[0],
             output_dim=d_model[1],
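
For reference, the two hunks above give the two encoder stacks staggered warmup windows instead of a single shared warmup count. A minimal sketch (not part of the diff) of the windows implied by the new default warmup_batches = 4000:

# Warmup windows implied by the constructor changes above (warmup_batches = 4000.0).
# For roughly the first third of warmup neither encoder stack is fully active,
# which is what the new comment about the Conv2dSubsampling layer refers to.
warmup_batches = 4000.0
encoder1_window = (warmup_batches / 3, 2 * warmup_batches / 3)  # ~(1333, 2667)
encoder2_window = (2 * warmup_batches / 3, warmup_batches)      # ~(2667, 4000)
print(encoder1_window, encoder2_window)
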
@@ -256,6 +261,11 @@ class ConformerEncoderLayer(nn.Module):
 
         self.d_model = d_model
 
+        # we'll overwrite these warmup_begin and warmup_end values from init of
+        # class ConformerEncoder.
+        self.warmup_begin = 0.0
+        self.warmup_end = 1000.0
+
         self.self_attn = RelPositionMultiheadAttention(
             d_model, nhead, dropout=dropout,
         )
@@ -267,7 +277,7 @@ class ConformerEncoderLayer(nn.Module):
             DoubleSwish(),
             nn.Dropout(dropout),
             ScaledLinear(feedforward_dim, d_model,
-                         initial_scale=0.1),
+                         initial_scale=0.01),
         )
 
         self.feed_forward_macaron = nn.Sequential(
@@ -277,7 +287,7 @@ class ConformerEncoderLayer(nn.Module):
             DoubleSwish(),
             nn.Dropout(dropout),
             ScaledLinear(feedforward_dim, d_model,
-                         initial_scale=0.1),
+                         initial_scale=0.01),
         )
 
         self.conv_module = ConvolutionModule(d_model,
@@ -293,6 +303,18 @@ class ConformerEncoderLayer(nn.Module):
             max_var_per_eig=0.2,
         )
 
+    def get_warmup_value(self, warmup_count: float) -> float:
+        """
+        Returns a value that is 0 at the start of training and increases to 1.0 during
+        a warmup period specified during model initialization.
+        """
+        if warmup_count < self.warmup_begin:
+            return 0.0
+        elif warmup_count > self.warmup_end:
+            return 1.0
+        else:
+            return (warmup_count - self.warmup_begin) / (self.warmup_end - self.warmup_begin)
+
     def forward(
         self,
         src: Tensor,
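
The get_warmup_value() method added above is a per-layer piecewise-linear ramp from 0 to 1 over [warmup_begin, warmup_end]. A standalone re-statement for a quick sanity check, using a hypothetical window of [1000, 2000]:

# Standalone copy of the ramp in get_warmup_value(), for illustration only.
def get_warmup_value(warmup_count: float, warmup_begin: float, warmup_end: float) -> float:
    if warmup_count < warmup_begin:
        return 0.0
    elif warmup_count > warmup_end:
        return 1.0
    else:
        return (warmup_count - warmup_begin) / (warmup_end - warmup_begin)

for count in (0.0, 1000.0, 1500.0, 2000.0, 5000.0):
    print(count, get_warmup_value(count, 1000.0, 2000.0))
# -> 0.0, 0.0, 0.5, 1.0, 1.0
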
@@ -300,6 +322,7 @@ class ConformerEncoderLayer(nn.Module):
         attn_scores_in: Optional[Tensor] = None,
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
+        warmup_count: float = 1.0e+10,
     ) -> Tuple[Tensor, Tensor]:
         """
         Pass the input through the encoder layer.
@@ -344,6 +367,14 @@ class ConformerEncoderLayer(nn.Module):
 
         src = self.norm_final(self.balancer(src))
 
+        warmup_value = self.get_warmup_value(warmup_count)
+        if warmup_value < 1.0 and self.training:
+            delta = torch.nn.functional.dropout(src_orig - src,
+                                                p=0.5 * (1. - warmup_value),
+                                                training=self.training)
+            src = src_orig + delta
+
+
         return src, attn_scores_out
 
 
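
The forward() change above implements the "gradually removes dropout from layers" part of the commit message: while a layer's warmup value is below 1.0, part of the difference between the layer's input (src_orig) and output is dropped out, so the layer is randomly bypassed element-wise, and the bypass probability shrinks to zero as warmup completes (the 1.0e+10 default keeps it off unless the encoder passes a real count). A self-contained sketch of that idea, written here with the delta taken as the layer's contribution (src - src_orig), so the output is left unchanged once the dropout probability reaches zero; note the hunk above writes the difference the other way around:

# Sketch of the warmup-dependent layer bypass (assumption: delta is the
# layer's contribution, i.e. output minus input).
import torch

def bypass_layer(src_orig: torch.Tensor, src: torch.Tensor,
                 warmup_value: float, training: bool = True) -> torch.Tensor:
    if warmup_value < 1.0 and training:
        # Per element: with probability 0.5 * (1 - warmup_value), drop the
        # layer's contribution and fall back to the layer's input.
        delta = torch.nn.functional.dropout(src - src_orig,
                                            p=0.5 * (1.0 - warmup_value),
                                            training=training)
        return src_orig + delta
    return src
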
@@ -359,25 +390,20 @@ class ConformerEncoder(nn.Module):
     >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
     >>> src = torch.rand(10, 32, 512)
     >>> out = conformer_encoder(src)
 
-
-    Returns: (combined_output, output),
-     where `combined_output` has gone through the RandomCombiner module and `output` is just the
-     original output, in case you need to bypass the RandomCombiner module.
     """
     def __init__(
         self,
         encoder_layer: nn.Module,
         num_layers: int,
         dropout: float,
-        warmup_batches: float,
+        warmup_begin: float,
+        warmup_end: float
     ) -> None:
         super().__init__()
 
         # keep track of how many times forward() has been called, for purposes of
-        # 'warmup'
-        self.register_buffer('count', torch.tensor(0, dtype=torch.int64))
-        self.warmup_batches = warmup_batches
+        # warmup
+        self.register_buffer('warmup_count', torch.tensor(0.0))
 
         self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model,
                                                  dropout)
@@ -387,22 +413,26 @@ class ConformerEncoder(nn.Module):
         )
         self.num_layers = num_layers
 
+        assert 0 <= warmup_begin <= warmup_end
 
-        num_channels = encoder_layer.norm_final.num_channels
+        delta = (1. / num_layers) * (warmup_end - warmup_begin)
+        cur_begin = warmup_begin
+        for i in range(num_layers):
+            self.layers[i].warmup_begin = cur_begin
+            cur_begin += delta
+            self.layers[i].warmup_end = cur_begin
 
 
 
-    def get_warmup_value(self) -> float:
+    def get_warmup_count(self) -> float:
         """
-        Returns a value that is 0 at the start of training and approaches 1.0 after a number of
-        'warmup' batches, specified in the constructor.
+        Returns a value that reflects how many times this function has been called in training mode.
         """
-        batch = self.count.item()
+        ans = self.warmup_count.item()
         if self.training:
-            self.count += 1
-            return min(1.0, batch / self.warmup_batches)
-        else:
-            return 1.0  # this is mostly a workaround for an issue with moderl averaging.
+            self.warmup_count += 1
+        return ans
 
 
     def forward(
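
The constructor changes above split the encoder's overall [warmup_begin, warmup_end] window evenly across its layers, so each layer ramps up over its own slice and deeper layers come online later; get_warmup_count() now just counts training-mode forward calls. A small sketch of the per-layer windows this produces, plugging in the encoder1 defaults (warmup_begin = 4000/3, warmup_end = 8000/3, 12 layers):

# Per-layer warmup windows produced by the loop in __init__ above.
def layer_windows(warmup_begin: float, warmup_end: float, num_layers: int):
    delta = (1. / num_layers) * (warmup_end - warmup_begin)
    windows = []
    cur_begin = warmup_begin
    for _ in range(num_layers):
        windows.append((cur_begin, cur_begin + delta))
        cur_begin += delta
    return windows

for i, (b, e) in enumerate(layer_windows(4000.0 / 3, 2 * 4000.0 / 3, 12)):
    print(f"layer {i:2d}: ramps from batch {b:7.1f} to {e:7.1f}")
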
@@ -411,7 +441,7 @@ class ConformerEncoder(nn.Module):
         feature_mask: Union[Tensor, float] = 1.0,
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
-    ) -> Tuple[Tensor, Tensor]:
+    ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.
 
         Args:
@@ -430,6 +460,7 @@ class ConformerEncoder(nn.Module):
 
         Returns: (x, x_no_combine), both of shape (S, N, E)
         """
+        warmup_count = self.get_warmup_count()  # reflects number of training batches.
         pos_emb = self.encoder_pos(src)
         output = src
 
@@ -438,50 +469,19 @@ class ConformerEncoder(nn.Module):
 
         output = output * feature_mask
 
-        outputs = [ output ]
-
-        # warmup starts at 0 at the beginning of training, reaches 1 at a few
-        # thousand minibatches, and then stays there.
-        warmup = self.get_warmup_value()
-
-        def apply_bypass(prev_output: Tensor, output: Tensor,
-                         warmup: float,
-                         min_output_scale: float = 0.1,
-                         max_output_scale: float = 1.0):
-            layer_dropout_prob = 0.075
-            if self.training and random.random() < layer_dropout_prob:
-                output_scale = 0.1
-            else:
-                output_scale = max(warmup * max_output_scale,
-                                   min_output_scale)
-            if output_scale == 1.0:
-                return output
-            else:
-                return output_scale * output + (1.0 - output_scale) * prev_output
-
         for i, mod in enumerate(self.layers):
-            output, attn_scores = mod(
-                outputs[-1],
+            next_output, attn_scores = mod(
+                output,
                 pos_emb,
                 attn_scores,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
+                warmup_count=warmup_count,
             )
-            # bypass this layer; the scale on `output` reaches a maximum of 0.5 which
-            # empirically seemed slightly better than 1.
-            output = apply_bypass(outputs[-1], output,
-                                  warmup, 0.1, 0.5)
-            # also apply bypass to twos and fours of layers.
-            #if i > 0 and i % 2 == 0:
-            #    output = apply_bypass(outputs[-2], output,
-            #                          warmup, 0.25, 1.0)
-            #if i > 0 and i % 4 == 0:
-            #    output = apply_bypass(outputs[-4], output,
-            #                          warmup, 0.25, 1.0)
-            output = output * feature_mask
-            outputs.append(output)
+            # this seemed to be helpful...
+            output = 0.5 * (next_output + output)
 
-        return outputs[-1]
+        return output
 
 
 class DownsampledConformerEncoder(nn.Module):
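
With apply_bypass removed, each layer's output is now blended with its input using a fixed 0.5 weight. Treating each layer's raw output as an independent term, this geometrically down-weights earlier layers; a short sketch of the effective weights after N layers, under that simplifying assumption:

# Effective weights after N applications of output = 0.5 * (next_output + output),
# ignoring interactions between layers: layer i (1-based) contributes 0.5 ** (N - i + 1),
# and the block's original input contributes 0.5 ** N.
N = 4
layer_weights = [0.5 ** (N - i + 1) for i in range(1, N + 1)]
input_weight = 0.5 ** N
print(layer_weights, input_weight)   # [0.0625, 0.125, 0.25, 0.5] 0.0625
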
@@ -788,7 +788,7 @@ class RelPositionMultiheadAttention(nn.Module):
                                  channel_dim=-1, max_abs=10.0,
                                  min_positive=0.0, max_positive=1.0)
         self.out_proj = ScaledLinear(
-            embed_dim // 2, embed_dim, bias=True, initial_scale=0.5
+            embed_dim // 2, embed_dim, bias=True, initial_scale=0.05
         )
 
         self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
@@ -1206,7 +1206,7 @@ class ConvolutionModule(nn.Module):
             stride=1,
             padding=0,
             bias=bias,
-            initial_scale=0.5,
+            initial_scale=0.05,
         )
 
     def forward(self,