Add warmup schedule for zipformer encoder layer, from 1.0 -> 0.2.

Daniel Povey 2022-10-30 14:41:18 +08:00
parent 8b0722e626
commit e9c69d8477


@@ -272,6 +272,9 @@ class ZipformerEncoderLayer(nn.Module):
         self.d_model = d_model
+        # will be written to, see set_batch_count()
+        self.batch_count = 0
         self.self_attn = RelPositionMultiheadAttention(
             d_model, attention_dim, nhead, pos_dim, dropout=0.0,
         )
@@ -310,6 +313,25 @@ class ZipformerEncoderLayer(nn.Module):
             prob=(0.025, 0.25),
             grad_scale=0.01)
+
+    def get_bypass_scale(self):
+        if torch.jit.is_scripting() or not self.training:
+            return self.bypass_scale
+        if random.random() < 0.1:
+            # ensure we get grads if self.bypass_scale becomes out of range
+            return self.bypass_scale
+        # hardcode warmup period for bypass scale
+        warmup_period = 4000.0
+        initial_clamp_min = 1.0
+        final_clamp_min = 0.2
+        if self.batch_count > warmup_period:
+            clamp_min = final_clamp_min
+        else:
+            clamp_min = (initial_clamp_min -
+                         (self.batch_count / warmup_period) * (initial_clamp_min - final_clamp_min))
+        return self.bypass_scale.clamp(min=clamp_min, max=1.0)
+
     def forward(
         self,
         src: Tensor,
@@ -363,14 +385,8 @@ class ZipformerEncoderLayer(nn.Module):
         src = self.norm_final(self.balancer(src))
         delta = src - src_orig
-        bypass_scale = self.bypass_scale
-        if torch.jit.is_scripting() or (not self.training) or random.random() > 0.1:
-            # with probability 0.9, in training mode, or always, in testing
-            # mode, clamp bypass_scale to [0.1, 1.0]; this will encourage it
-            # to learn parameters within this range by making parameters that
-            # are outside that range noisy.
-            bypass_scale = bypass_scale.clamp(min=0.5, max=1.0)
-        src = src_orig + delta * bypass_scale
+        src = src_orig + delta * self.get_bypass_scale()
         return self.whiten(src)
@@ -397,9 +413,9 @@ class ZipformerEncoder(nn.Module):
             warmup_end: float
     ) -> None:
         super().__init__()
-        # self.batch_count will be written to by the top-level training program.
-        # Note: in inference time this may be zero but should be treated as large,
-        # we can check if self.training is true.
+        # will be written to, see set_batch_count() Note: in inference time this
+        # may be zero but should be treated as large, we can check if
+        # self.training is true.
         self.batch_count = 0
         self.warmup_begin = warmup_begin
         self.warmup_end = warmup_end
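
Note (not part of the commit): the schedule implemented by the new get_bypass_scale() is a linear ramp of the clamp minimum from 1.0 down to 0.2 over the first 4000 batches, after which it stays at 0.2. Early in training the bypass scale is therefore pinned at 1.0 (each layer's full output is used), and only later can it be learned down towards 0.2 (partially bypassing the layer). A minimal standalone sketch, with the constants taken from the diff above and the helper name bypass_clamp_min chosen only for illustration:

def bypass_clamp_min(batch_count: float,
                     warmup_period: float = 4000.0,
                     initial_clamp_min: float = 1.0,
                     final_clamp_min: float = 0.2) -> float:
    """Lower bound applied to the learned bypass_scale at a given batch count."""
    if batch_count > warmup_period:
        return final_clamp_min
    # Linear interpolation from initial_clamp_min down to final_clamp_min.
    return initial_clamp_min - (batch_count / warmup_period) * (initial_clamp_min - final_clamp_min)

# e.g. batch 0 -> 1.0, batch 1000 -> 0.8, batch 2000 -> 0.6, batch 4000+ -> 0.2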
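
The schedule relies on self.batch_count being written from outside the module; this diff only references set_batch_count() in a comment. Below is a hypothetical sketch of how a top-level training loop might push the counter into every layer; it is not the actual icefall helper, whose details may differ:

import torch.nn as nn

def set_batch_count(model: nn.Module, batch_count: float) -> None:
    # Hypothetical: write the global batch counter into every submodule that
    # exposes a batch_count attribute (e.g. ZipformerEncoderLayer, ZipformerEncoder).
    for module in model.modules():
        if hasattr(module, "batch_count"):
            module.batch_count = batch_count

# In the training loop (illustrative):
#     set_batch_count(model, global_batch_idx)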