mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Apply layer bypass during warmup in a new way, including 2s and 4s of layers.
This commit is contained in:
parent
314f2381e2
commit
ebf8aa129d
@ -49,6 +49,8 @@ class Conformer(EncoderInterface):
|
||||
layer_dropout (float): layer-dropout rate.
|
||||
cnn_module_kernel (int): Kernel size of convolution module
|
||||
vgg_frontend (bool): whether to use vgg frontend.
|
||||
warmup (float): number of batches to warm up over (gradually skip
|
||||
layer bypass)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -63,6 +65,7 @@ class Conformer(EncoderInterface):
|
||||
num_encoder_layers: Tuple[int] = (12, 12),
|
||||
dropout: float = 0.1,
|
||||
cnn_module_kernel: Tuple[int] = (31, 31),
|
||||
warmup_batches: float = 3000,
|
||||
) -> None:
|
||||
super(Conformer, self).__init__()
|
||||
|
||||
@ -98,6 +101,7 @@ class Conformer(EncoderInterface):
|
||||
encoder_layer1,
|
||||
num_encoder_layers[0],
|
||||
dropout,
|
||||
warmup_batches,
|
||||
)
|
||||
encoder_layer2 = ConformerEncoderLayer(
|
||||
d_model[1],
|
||||
@ -111,6 +115,7 @@ class Conformer(EncoderInterface):
|
||||
encoder_layer2,
|
||||
num_encoder_layers[1],
|
||||
dropout,
|
||||
warmup_batches,
|
||||
),
|
||||
input_dim=d_model[0],
|
||||
output_dim=d_model[1],
|
||||
@ -365,10 +370,16 @@ class ConformerEncoder(nn.Module):
|
||||
self,
|
||||
encoder_layer: nn.Module,
|
||||
num_layers: int,
|
||||
dropout: float
|
||||
dropout: float,
|
||||
warmup_batches: float,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# keep track of how many times forward() has been called, for purposes of
|
||||
# 'warmup'
|
||||
self.register_buffer('count', torch.tensor(0, dtype=torch.int64))
|
||||
self.warmup_batches = warmup_batches
|
||||
|
||||
self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model,
|
||||
dropout)
|
||||
|
||||
@ -381,6 +392,18 @@ class ConformerEncoder(nn.Module):
|
||||
num_channels = encoder_layer.norm_final.num_channels
|
||||
|
||||
|
||||
|
||||
def get_warmup_value(self) -> float:
|
||||
"""
|
||||
Returns a value that is 0 at the start of training and approaches 1.0 after a number of
|
||||
'warmup' batches, specified in the constructor.
|
||||
"""
|
||||
batch = self.count.item()
|
||||
if self.training:
|
||||
self.count += 1
|
||||
return min(1.0, batch / self.warmup_batches)
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src: Tensor,
|
||||
@ -414,30 +437,46 @@ class ConformerEncoder(nn.Module):
|
||||
|
||||
output = output * feature_mask
|
||||
|
||||
num_layers = len(self.layers)
|
||||
indexes = list(range(num_layers))
|
||||
if self.training:
|
||||
num_to_swap = int((num_layers - 1) * 0.1)
|
||||
for i in range(num_to_swap):
|
||||
j = random.randint(0, num_layers - 2)
|
||||
a,b = indexes[j], indexes[j+1]
|
||||
indexes[j] = b
|
||||
indexes[j+1] = a
|
||||
outputs = [ output ]
|
||||
|
||||
# warmup starts at 0 at the beginning of training, reaches 1 at a few
|
||||
# thousand minibatches, and then stays there.
|
||||
warmup = self.get_warmup_value()
|
||||
|
||||
for i in indexes:
|
||||
mod = self.layers[i]
|
||||
def apply_bypass(prev_output: Tensor, output: Tensor,
|
||||
warmup: float,
|
||||
min_output_scale: float = 0.1,
|
||||
max_output_scale: float = 1.0):
|
||||
output_scale = max(warmup * max_output_scale,
|
||||
min_output_scale)
|
||||
if output_scale == 1.0:
|
||||
return output
|
||||
else:
|
||||
return output_scale * output + (1.0 - output_scale) * prev_output
|
||||
|
||||
for i, mod in enumerate(self.layers):
|
||||
output, attn_scores = mod(
|
||||
output,
|
||||
outputs[-1],
|
||||
pos_emb,
|
||||
attn_scores,
|
||||
src_mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
)
|
||||
|
||||
# bypass this layer; the scale on `output` reaches a maximum of 0.5 which
|
||||
# empirically seemed slightly better than 1.
|
||||
output = apply_bypass(outputs[-1], output,
|
||||
warmup, 0.1, 0.5)
|
||||
# also apply bypass to twos and fours of layers.
|
||||
if i > 0 and i % 2 == 0:
|
||||
output = apply_bypass(outputs[-2], output,
|
||||
warmup, 0.25, 1.0)
|
||||
if i > 0 and i % 4 == 0:
|
||||
output = apply_bypass(outputs[-4], output,
|
||||
warmup, 0.25, 1.0)
|
||||
output = output * feature_mask
|
||||
outputs.append(output)
|
||||
|
||||
return output
|
||||
return outputs[-1]
|
||||
|
||||
|
||||
class DownsampledConformerEncoder(nn.Module):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user