Mirror of https://github.com/k2-fsa/icefall.git
Add warmup schedule for zipformer encoder layer, from 1.0 -> 0.2.
commit e9c69d8477
parent 8b0722e626
@@ -272,6 +272,9 @@ class ZipformerEncoderLayer(nn.Module):
 
         self.d_model = d_model
 
+        # will be written to, see set_batch_count()
+        self.batch_count = 0
+
         self.self_attn = RelPositionMultiheadAttention(
             d_model, attention_dim, nhead, pos_dim, dropout=0.0,
         )
@@ -310,6 +313,25 @@ class ZipformerEncoderLayer(nn.Module):
                              prob=(0.025, 0.25),
                              grad_scale=0.01)
 
+    def get_bypass_scale(self):
+        if torch.jit.is_scripting() or not self.training:
+            return self.bypass_scale
+        if random.random() < 0.1:
+            # ensure we get grads if self.bypass_scale becomes out of range
+            return self.bypass_scale
+        # hardcode warmup period for bypass scale
+        warmup_period = 4000.0
+        initial_clamp_min = 1.0
+        final_clamp_min = 0.2
+        if self.batch_count > warmup_period:
+            clamp_min = final_clamp_min
+        else:
+            clamp_min = (initial_clamp_min -
+                         (self.batch_count / warmup_period) * (initial_clamp_min - final_clamp_min))
+        return self.bypass_scale.clamp(min=clamp_min, max=1.0)
+
+
+
     def forward(
         self,
         src: Tensor,
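
Note (not part of the diff): get_bypass_scale() decays the clamp floor linearly from 1.0 to 0.2 over the first 4000 batches, so early in training the layer output is passed through at close to full strength and only later is it allowed to be down-weighted. A minimal standalone sketch of that schedule, using a hypothetical helper name:

def bypass_clamp_min(batch_count: float,
                     warmup_period: float = 4000.0,
                     initial_clamp_min: float = 1.0,
                     final_clamp_min: float = 0.2) -> float:
    # clamp floor decays linearly from initial_clamp_min to final_clamp_min,
    # then stays at final_clamp_min once the warmup period has passed
    if batch_count > warmup_period:
        return final_clamp_min
    return initial_clamp_min - (batch_count / warmup_period) * (initial_clamp_min - final_clamp_min)

# e.g. bypass_clamp_min(0.0) == 1.0, bypass_clamp_min(2000.0) == 0.6, bypass_clamp_min(4000.0) == 0.2
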
@@ -363,14 +385,8 @@ class ZipformerEncoderLayer(nn.Module):
         src = self.norm_final(self.balancer(src))
 
         delta = src - src_orig
-        bypass_scale = self.bypass_scale
-        if torch.jit.is_scripting() or (not self.training) or random.random() > 0.1:
-            # with probability 0.9, in training mode, or always, in testing
-            # mode, clamp bypass_scale to [ 0.1, 1.0 ]; this will encourage it
-            # to learn parameters within this range by making parameters that
-            # are outside that range noisy.
-            bypass_scale = bypass_scale.clamp(min=0.5, max=1.0)
-        src = src_orig + delta * bypass_scale
+
+        src = src_orig + delta * self.get_bypass_scale()
 
         return self.whiten(src)
 
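
Note (illustration only, shapes are assumed): the bypass combines the layer input and its processed output as a learned per-channel interpolation; src_orig + delta * scale is algebraically (1 - scale) * src_orig + scale * src, so a scale clamped near 1.0 keeps the full layer output and smaller values lean back toward the input.

import torch

src_orig = torch.randn(10, 2, 512)      # assumed (seq_len, batch, d_model)
src = torch.randn(10, 2, 512)           # stands in for the layer's processed output
bypass_scale = torch.full((512,), 0.8)  # per-channel scale after clamping

delta = src - src_orig
out = src_orig + delta * bypass_scale   # == (1 - bypass_scale) * src_orig + bypass_scale * src
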
@@ -397,9 +413,9 @@ class ZipformerEncoder(nn.Module):
         warmup_end: float
     ) -> None:
         super().__init__()
-        # self.batch_count will be written to by the top-level training program.
-        # Note: in inference time this may be zero but should be treated as large,
-        # we can check if self.training is true.
+        # will be written to, see set_batch_count() Note: in inference time this
+        # may be zero but should be treated as large, we can check if
+        # self.training is true.
         self.batch_count = 0
         self.warmup_begin = warmup_begin
         self.warmup_end = warmup_end
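
Note (sketch, not from this commit): the new comments point at a set_batch_count() helper that the training loop is expected to call each step; something along these lines would keep every module's batch_count in sync, though the actual icefall helper may differ:

import torch.nn as nn

def set_batch_count(model: nn.Module, batch_count: float) -> None:
    # update every submodule that tracks a batch_count attribute,
    # e.g. ZipformerEncoder and ZipformerEncoderLayer
    for module in model.modules():
        if hasattr(module, "batch_count"):
            module.batch_count = batch_count
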