Have warmup that gradually removes dropout from layers; multiply initialization scales by 0.1.

Daniel Povey 2022-10-08 12:45:22 +08:00
parent 300da1306d
commit fe4a7e904f


@@ -64,7 +64,7 @@ class Conformer(EncoderInterface):
num_encoder_layers: Tuple[int] = (12, 12),
dropout: float = 0.1,
cnn_module_kernel: Tuple[int] = (31, 31),
warmup_batches: float = 6000.0,
warmup_batches: float = 4000.0,
) -> None:
super(Conformer, self).__init__()
@@ -96,11 +96,14 @@ class Conformer(EncoderInterface):
cnn_module_kernel[0],
)
# for the first third of the warmup period, we let the Conv2dSubsampling
# layer learn something
self.encoder1 = ConformerEncoder(
encoder_layer1,
num_encoder_layers[0],
dropout,
warmup_batches,
warmup_begin=warmup_batches / 3,
warmup_end=2 * warmup_batches / 3,
)
encoder_layer2 = ConformerEncoderLayer(
d_model[1],
@@ -108,13 +111,15 @@ class Conformer(EncoderInterface):
feedforward_dim[1],
dropout,
cnn_module_kernel[1],
)
self.encoder2 = DownsampledConformerEncoder(
ConformerEncoder(
encoder_layer2,
num_encoder_layers[1],
dropout,
warmup_batches,
warmup_begin=2 * warmup_batches / 3,
warmup_end=warmup_batches,
),
input_dim=d_model[0],
output_dim=d_model[1],
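
To make the scheduling in the hunk above concrete: the warmup period of warmup_batches (now 4000 by default) is split into thirds, with the first third left to the Conv2dSubsampling frontend, encoder1 warming up over the second third and encoder2 over the last third. A minimal sketch of those windows; the helper name warmup_windows is made up for illustration and does not exist in the repository:

def warmup_windows(warmup_batches: float = 4000.0):
    # first third: only the Conv2dSubsampling frontend learns; both encoders are still warming up
    encoder1 = (warmup_batches / 3, 2 * warmup_batches / 3)   # second third
    encoder2 = (2 * warmup_batches / 3, warmup_batches)       # last third
    return encoder1, encoder2

print(warmup_windows())   # ((1333.33..., 2666.66...), (2666.66..., 4000.0))
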
@@ -256,6 +261,11 @@ class ConformerEncoderLayer(nn.Module):
self.d_model = d_model
# we'll overwrite these warmup_begin and warmup_end values from init of
# class ConformerEncoder.
self.warmup_begin = 0.0
self.warmup_end = 1000.0
self.self_attn = RelPositionMultiheadAttention(
d_model, nhead, dropout=dropout,
)
@@ -267,7 +277,7 @@ class ConformerEncoderLayer(nn.Module):
DoubleSwish(),
nn.Dropout(dropout),
ScaledLinear(feedforward_dim, d_model,
initial_scale=0.1),
initial_scale=0.01),
)
self.feed_forward_macaron = nn.Sequential(
@@ -277,7 +287,7 @@ class ConformerEncoderLayer(nn.Module):
DoubleSwish(),
nn.Dropout(dropout),
ScaledLinear(feedforward_dim, d_model,
initial_scale=0.1),
initial_scale=0.01),
)
self.conv_module = ConvolutionModule(d_model,
@@ -293,6 +303,18 @@ class ConformerEncoderLayer(nn.Module):
max_var_per_eig=0.2,
)
def get_warmup_value(self, warmup_count: float) -> float:
"""
Returns a value that is 0 at the start of training and increases to 1.0 during
a warmup period specified during model initialization.
"""
if warmup_count < self.warmup_begin:
return 0.0
elif warmup_count > self.warmup_end:
return 1.0
else:
return (warmup_count - self.warmup_begin) / (self.warmup_end - self.warmup_begin)
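
The new get_warmup_value() is a simple linear ramp from 0 to 1 between each layer's warmup_begin and warmup_end. A standalone sketch of the same ramp, with arbitrary example bounds:

def warmup_value(count: float, begin: float = 1000.0, end: float = 2000.0) -> float:
    # Linear ramp from 0 to 1 over [begin, end]; mirrors get_warmup_value() above.
    if count < begin:
        return 0.0
    elif count > end:
        return 1.0
    return (count - begin) / (end - begin)

assert warmup_value(500.0) == 0.0     # before this layer's window
assert warmup_value(1500.0) == 0.5    # halfway through it
assert warmup_value(3000.0) == 1.0    # warmup finished
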
def forward(
self,
src: Tensor,
@@ -300,6 +322,7 @@ class ConformerEncoderLayer(nn.Module):
attn_scores_in: Optional[Tensor] = None,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
warmup_count: float = 1.0e+10,
) -> Tuple[Tensor, Tensor]:
"""
Pass the input through the encoder layer.
@@ -344,6 +367,14 @@ class ConformerEncoderLayer(nn.Module):
src = self.norm_final(self.balancer(src))
warmup_value = self.get_warmup_value(warmup_count)
if warmup_value < 1.0 and self.training:
delta = torch.nn.functional.dropout(src - src_orig,
p=0.5 * (1. - warmup_value),
training=self.training)
src = src_orig + delta
return src, attn_scores_out
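
This is the core of the "gradually removes dropout" idea: while a layer's warmup_value is below 1.0, its contribution (src - src_orig) is passed through dropout with p = 0.5 * (1 - warmup_value), so early in the layer's window roughly half of the per-element updates are dropped (with the survivors rescaled by dropout as usual), and the dropout fades away by the end of the window. A self-contained sketch of just this bypass, with illustrative shapes:

import torch

def warmup_bypass(src_orig: torch.Tensor, src: torch.Tensor,
                  warmup_value: float, training: bool = True) -> torch.Tensor:
    # Drop out part of the layer's contribution while warmup_value < 1.0.
    if warmup_value >= 1.0 or not training:
        return src
    delta = torch.nn.functional.dropout(src - src_orig,
                                        p=0.5 * (1.0 - warmup_value),
                                        training=training)
    return src_orig + delta

x_in = torch.randn(10, 2, 8)                     # (seq_len, batch, d_model), illustrative
x_out = x_in + 0.1 * torch.randn_like(x_in)      # stands in for the encoder layer's output
y = warmup_bypass(x_in, x_out, warmup_value=0.25)   # p = 0.375 of the delta is dropped
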
@@ -359,25 +390,20 @@ class ConformerEncoder(nn.Module):
>>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
>>> src = torch.rand(10, 32, 512)
>>> out = conformer_encoder(src)
Returns: (combined_output, output),
where `combined_output` has gone through the RandomCombiner module and `output` is just the
original output, in case you need to bypass the RandomCombiner module.
"""
def __init__(
self,
encoder_layer: nn.Module,
num_layers: int,
dropout: float,
warmup_batches: float,
warmup_begin: float,
warmup_end: float
) -> None:
super().__init__()
# keep track of how many times forward() has been called, for purposes of
# 'warmup'
self.register_buffer('count', torch.tensor(0, dtype=torch.int64))
self.warmup_batches = warmup_batches
# warmup
self.register_buffer('warmup_count', torch.tensor(0.0))
self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model,
dropout)
@@ -387,22 +413,26 @@ class ConformerEncoder(nn.Module):
)
self.num_layers = num_layers
assert 0 <= warmup_begin <= warmup_end
num_channels = encoder_layer.norm_final.num_channels
delta = (1. / num_layers) * (warmup_end - warmup_begin)
cur_begin = warmup_begin
for i in range(num_layers):
self.layers[i].warmup_begin = cur_begin
cur_begin += delta
self.layers[i].warmup_end = cur_begin
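
The loop above carves the encoder's overall [warmup_begin, warmup_end] window into num_layers equal slices, so layer i finishes its warmup exactly when layer i+1 starts. The same assignment written as a free function, with example numbers matching encoder1 above (12 layers, window from warmup_batches/3 to 2*warmup_batches/3 with warmup_batches = 4000); the helper is illustrative only:

def layer_windows(warmup_begin: float, warmup_end: float, num_layers: int):
    # Equal, back-to-back per-layer slices of [warmup_begin, warmup_end].
    delta = (warmup_end - warmup_begin) / num_layers
    windows = []
    cur_begin = warmup_begin
    for _ in range(num_layers):
        windows.append((cur_begin, cur_begin + delta))
        cur_begin += delta
    return windows

for begin, end in layer_windows(4000.0 / 3, 2 * 4000.0 / 3, 12):
    print(f"{begin:8.1f} -> {end:8.1f}")   # 1333.3 -> 1444.4, 1444.4 -> 1555.6, ...
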
def get_warmup_value(self) -> float:
def get_warmup_count(self) -> float:
"""
Returns a value that is 0 at the start of training and approaches 1.0 after a number of
'warmup' batches, specified in the constructor.
Returns a value that reflects how many times this function has been called in training mode.
"""
batch = self.count.item()
ans = self.warmup_count.item()
if self.training:
self.count += 1
return min(1.0, batch / self.warmup_batches)
else:
return 1.0 # this is mostly a workaround for an issue with model averaging.
self.warmup_count += 1
return ans
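
get_warmup_count() replaces the old count/warmup_batches logic: the counter is a float buffer, so it is saved and restored with state_dict(), and it only advances while the module is in training mode. A sketch of just the counter behaviour; the class name is hypothetical:

import torch
import torch.nn as nn

class WarmupCounter(nn.Module):
    # Hypothetical stand-in for the counter above: a float buffer, saved with
    # state_dict(), incremented only while the module is in training mode.
    def __init__(self):
        super().__init__()
        self.register_buffer('warmup_count', torch.tensor(0.0))

    def get_warmup_count(self) -> float:
        ans = self.warmup_count.item()
        if self.training:
            self.warmup_count += 1
        return ans

c = WarmupCounter()
c.train()
print([c.get_warmup_count() for _ in range(3)])   # [0.0, 1.0, 2.0]
c.eval()
print(c.get_warmup_count())                       # 3.0, and it stays there in eval mode
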
def forward(
@@ -411,7 +441,7 @@ class ConformerEncoder(nn.Module):
feature_mask: Union[Tensor, float] = 1.0,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
@@ -430,6 +460,7 @@ class ConformerEncoder(nn.Module):
Returns: (x, x_no_combine), both of shape (S, N, E)
"""
warmup_count = self.get_warmup_count() # reflects number of training batches.
pos_emb = self.encoder_pos(src)
output = src
@@ -438,50 +469,19 @@ class ConformerEncoder(nn.Module):
output = output * feature_mask
outputs = [ output ]
# warmup starts at 0 at the beginning of training, reaches 1 at a few
# thousand minibatches, and then stays there.
warmup = self.get_warmup_value()
def apply_bypass(prev_output: Tensor, output: Tensor,
warmup: float,
min_output_scale: float = 0.1,
max_output_scale: float = 1.0):
layer_dropout_prob = 0.075
if self.training and random.random() < layer_dropout_prob:
output_scale = 0.1
else:
output_scale = max(warmup * max_output_scale,
min_output_scale)
if output_scale == 1.0:
return output
else:
return output_scale * output + (1.0 - output_scale) * prev_output
for i, mod in enumerate(self.layers):
output, attn_scores = mod(
outputs[-1],
next_output, attn_scores = mod(
output,
pos_emb,
attn_scores,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup_count=warmup_count,
)
# bypass this layer; the scale on `output` reaches a maximum of 0.5 which
# empirically seemed slightly better than 1.
output = apply_bypass(outputs[-1], output,
warmup, 0.1, 0.5)
# also apply bypass to twos and fours of layers.
#if i > 0 and i % 2 == 0:
# output = apply_bypass(outputs[-2], output,
# warmup, 0.25, 1.0)
#if i > 0 and i % 4 == 0:
# output = apply_bypass(outputs[-4], output,
# warmup, 0.25, 1.0)
output = output * feature_mask
outputs.append(output)
# this seemed to be helpful...
output = 0.5 * (next_output + output)
return outputs[-1]
return output
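
The layer loop itself is also simplified: the apply_bypass helper and its random layer dropout are gone, each layer now receives warmup_count and handles its own dropout-based bypass internally, and each layer's output is averaged 50/50 with its input. A condensed sketch of the new loop, with masks, positional embeddings and attention-score plumbing omitted and stand-in layers:

import torch
import torch.nn as nn

def run_layers(layers, src: torch.Tensor) -> torch.Tensor:
    # Simplified version of the loop above; each stand-in layer maps output -> output.
    output = src
    for mod in layers:
        next_output = mod(output)
        output = 0.5 * (next_output + output)   # average the layer's output with its input
    return output

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])   # stand-in layers
y = run_layers(layers, torch.randn(10, 2, 8))
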
class DownsampledConformerEncoder(nn.Module):
@@ -788,7 +788,7 @@ class RelPositionMultiheadAttention(nn.Module):
channel_dim=-1, max_abs=10.0,
min_positive=0.0, max_positive=1.0)
self.out_proj = ScaledLinear(
embed_dim // 2, embed_dim, bias=True, initial_scale=0.5
embed_dim // 2, embed_dim, bias=True, initial_scale=0.05
)
self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
@@ -1206,7 +1206,7 @@ class ConvolutionModule(nn.Module):
stride=1,
padding=0,
bias=bias,
initial_scale=0.5,
initial_scale=0.05,
)
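
The last two hunks carry out the second half of the commit message, multiplying several initial_scale values by 0.1 (0.1 -> 0.01 in the feed-forward output projections, 0.5 -> 0.05 in the attention and convolution output projections). Roughly speaking, initial_scale in icefall's ScaledLinear / ScaledConv1d shrinks the module's initial output magnitude, so these projections start out contributing very little; the exact mechanism lives in scaling.py and is not reproduced here. A rough sketch of that idea on a plain nn.Linear, not the actual implementation:

import torch
import torch.nn as nn

def shrink_init_(module: nn.Linear, initial_scale: float) -> nn.Linear:
    # Approximation of what a smaller initial_scale does: scale down the default
    # initialization so the layer's output is near zero at the start of training.
    with torch.no_grad():
        module.weight.mul_(initial_scale)
        if module.bias is not None:
            module.bias.mul_(initial_scale)
    return module

proj = shrink_init_(nn.Linear(256, 512), initial_scale=0.05)
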
def forward(self,