Apply layer bypass during warmup in a new way, including 2s and 4s of layers.

Daniel Povey 2022-10-07 16:56:40 +08:00
parent 314f2381e2
commit ebf8aa129d


@@ -49,6 +49,8 @@ class Conformer(EncoderInterface):
         layer_dropout (float): layer-dropout rate.
         cnn_module_kernel (int): Kernel size of convolution module
         vgg_frontend (bool): whether to use vgg frontend.
+        warmup_batches (float): number of batches to warm up over
+          (gradually remove the layer bypass)
     """

     def __init__(
@@ -63,6 +65,7 @@ class Conformer(EncoderInterface):
         num_encoder_layers: Tuple[int] = (12, 12),
         dropout: float = 0.1,
         cnn_module_kernel: Tuple[int] = (31, 31),
+        warmup_batches: float = 3000,
     ) -> None:
         super(Conformer, self).__init__()
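For reference, a minimal usage sketch of the new constructor argument. Only `warmup_batches` comes from this diff; `num_features` and its value are assumptions for illustration:

    # Hypothetical instantiation: with warmup_batches=3000 (the default),
    # the layer bypass is phased out over roughly the first 3000 batches.
    model = Conformer(
        num_features=80,        # assumed input dimension, not shown in this diff
        warmup_batches=3000.0,
    )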
@@ -98,6 +101,7 @@ class Conformer(EncoderInterface):
                 encoder_layer1,
                 num_encoder_layers[0],
                 dropout,
+                warmup_batches,
             )
         encoder_layer2 = ConformerEncoderLayer(
             d_model[1],
@@ -111,6 +115,7 @@ class Conformer(EncoderInterface):
                 encoder_layer2,
                 num_encoder_layers[1],
                 dropout,
+                warmup_batches,
             ),
             input_dim=d_model[0],
             output_dim=d_model[1],
@@ -362,13 +367,19 @@ class ConformerEncoder(nn.Module):
        original output, in case you need to bypass the RandomCombiner module.
     """

     def __init__(
         self,
         encoder_layer: nn.Module,
         num_layers: int,
-        dropout: float
+        dropout: float,
+        warmup_batches: float,
     ) -> None:
         super().__init__()
+        # keep track of how many times forward() has been called, for purposes of
+        # 'warmup'
+        self.register_buffer('count', torch.tensor(0, dtype=torch.int64))
+        self.warmup_batches = warmup_batches
         self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model,
                                                  dropout)
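A note on the `register_buffer` call above: buffers are saved in `state_dict()` alongside parameters, so the batch counter, and with it the warmup progress, survives a checkpoint save/restore. A minimal standalone sketch of just that behavior (not code from the commit):

    import torch
    import torch.nn as nn

    class Counter(nn.Module):
        def __init__(self):
            super().__init__()
            # like `count` above: saved with the model, but not trainable
            self.register_buffer('count', torch.tensor(0, dtype=torch.int64))

    m1 = Counter()
    m1.count += 500                      # pretend 500 batches have been seen
    m2 = Counter()
    m2.load_state_dict(m1.state_dict())  # "resume from a checkpoint"
    assert m2.count.item() == 500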
@@ -381,6 +392,18 @@ class ConformerEncoder(nn.Module):
         num_channels = encoder_layer.norm_final.num_channels

+    def get_warmup_value(self) -> float:
+        """
+        Returns a value that is 0 at the start of training and approaches 1.0
+        after a number of 'warmup' batches, specified in the constructor.
+        """
+        batch = self.count.item()
+        if self.training:
+            self.count += 1
+        return min(1.0, batch / self.warmup_batches)
+
     def forward(
         self,
         src: Tensor,
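The schedule implemented by `get_warmup_value()` is a linear ramp clamped at 1.0; the counter is read before being incremented, and is only incremented in training mode, so evaluation passes do not advance the warmup. A standalone restatement of the ramp, using the default of 3000 warmup batches:

    def warmup_value(batch_count: int, warmup_batches: float = 3000.0) -> float:
        # linear ramp from 0 toward 1, then flat at 1
        return min(1.0, batch_count / warmup_batches)

    for b in (0, 750, 1500, 3000, 10000):
        print(b, warmup_value(b))  # 0.0, 0.25, 0.5, 1.0, 1.0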
@@ -414,30 +437,46 @@ class ConformerEncoder(nn.Module):
         output = output * feature_mask

-        num_layers = len(self.layers)
-        indexes = list(range(num_layers))
-        if self.training:
-            num_to_swap = int((num_layers - 1) * 0.1)
-            for i in range(num_to_swap):
-                j = random.randint(0, num_layers - 2)
-                a,b = indexes[j], indexes[j+1]
-                indexes[j] = b
-                indexes[j+1] = a
+        outputs = [ output ]

-        for i in indexes:
-            mod = self.layers[i]
+        # warmup starts at 0 at the beginning of training, reaches 1 at a few
+        # thousand minibatches, and then stays there.
+        warmup = self.get_warmup_value()
+
+        def apply_bypass(prev_output: Tensor, output: Tensor,
+                         warmup: float,
+                         min_output_scale: float = 0.1,
+                         max_output_scale: float = 1.0):
+            output_scale = max(warmup * max_output_scale,
+                               min_output_scale)
+            if output_scale == 1.0:
+                return output
+            else:
+                return output_scale * output + (1.0 - output_scale) * prev_output
+
+        for i, mod in enumerate(self.layers):
             output, attn_scores = mod(
-                output,
+                outputs[-1],
                 pos_emb,
                 attn_scores,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
             )
+            # bypass this layer; the scale on `output` reaches a maximum of 0.5
+            # which empirically seemed slightly better than 1.
+            output = apply_bypass(outputs[-1], output,
+                                  warmup, 0.1, 0.5)
+            # also apply bypass to twos and fours of layers.
+            if i > 0 and i % 2 == 0:
+                output = apply_bypass(outputs[-2], output,
+                                      warmup, 0.25, 1.0)
+            if i > 0 and i % 4 == 0:
+                output = apply_bypass(outputs[-4], output,
+                                      warmup, 0.25, 1.0)
             output = output * feature_mask
+            outputs.append(output)

-        return output
+        return outputs[-1]


 class DownsampledConformerEncoder(nn.Module):
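To summarize the new forward pass: `apply_bypass()` forms a convex combination of a layer's output with an earlier entry of the `outputs` list. Since `outputs[0]` is the input to the first layer and entry k is the output after layer k-1, `outputs[-2]` at layer i is the value from two layers back and `outputs[-4]` the value from four layers back; those are the "2s and 4s of layers" in the commit message. A standalone sketch of the combination (same logic as the nested function in the diff; the example values are illustrative):

    import torch
    from torch import Tensor

    def apply_bypass(prev_output: Tensor, output: Tensor, warmup: float,
                     min_output_scale: float = 0.1,
                     max_output_scale: float = 1.0) -> Tensor:
        # scale on the new output: grows with warmup, floored at
        # min_output_scale and capped at max_output_scale
        output_scale = max(warmup * max_output_scale, min_output_scale)
        if output_scale == 1.0:
            return output  # bypass fully removed
        return output_scale * output + (1.0 - output_scale) * prev_output

    x_prev = torch.zeros(3)
    x = torch.ones(3)
    print(apply_bypass(x_prev, x, warmup=0.0))  # tensor of 0.1: mostly bypassed
    print(apply_bypass(x_prev, x, warmup=1.0))  # tensor of 1.0: no bypass

Note that the per-layer call uses max_output_scale=0.5, so even after warmup each layer's output remains a 50/50 blend with its input; only the two-layer and four-layer bypasses (max_output_scale=1.0) disappear entirely once warmup reaches 1.0.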