mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Start the whitening schedules for activations in NonlinAttentionModule and AttentionSqueezeModule at lower values; increase some whitening probabilities.
This commit is contained in:
parent
0ac26f4234
commit
dd3826104e
@ -336,9 +336,9 @@ class Zipformer(EncoderInterface):
|
||||
return x, lengths
|
||||
|
||||
|
||||
def _whitening_schedule(x: float) -> ScheduledFloat:
|
||||
def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
|
||||
return ScheduledFloat((0.0, x),
|
||||
(12000.0, 2.0 * x),
|
||||
(12000.0, ratio * x),
|
||||
default=x)
|
||||
|
||||
class ZipformerEncoderLayer(nn.Module):
|
||||
@ -429,7 +429,7 @@ class ZipformerEncoderLayer(nn.Module):
|
||||
max_abs=6.0,
|
||||
)
|
||||
self.whiten = Whiten(num_groups=1,
|
||||
whitening_limit=_whitening_schedule(4.0),
|
||||
whitening_limit=_whitening_schedule(4.0, ratio=3.0),
|
||||
prob=(0.025, 0.25),
|
||||
grad_scale=0.01)
|
||||
|
||||
@ -1330,7 +1330,7 @@ class AttentionSqueeze(nn.Module):
|
||||
min_prob=0.05,
|
||||
)
|
||||
self.activation_whiten = Whiten(num_groups=1,
|
||||
whitening_limit=_whitening_schedule(7.5),
|
||||
whitening_limit=_whitening_schedule(4.0, ratio=3.0),
|
||||
prob=(0.025, 0.25),
|
||||
grad_scale=0.01)
|
||||
|
||||
@ -1438,12 +1438,12 @@ class NonlinAttentionModule(nn.Module):
|
||||
initial_scale=0.05)
|
||||
|
||||
self.whiten1 = Whiten(num_groups=1,
|
||||
whitening_limit=_whitening_schedule(7.5),
|
||||
prob=(0.01, 0.1),
|
||||
whitening_limit=_whitening_schedule(5.0),
|
||||
prob=(0.025, 0.25),
|
||||
grad_scale=0.01)
|
||||
self.whiten2 = Whiten(num_groups=1,
|
||||
whitening_limit=_whitening_schedule(7.5),
|
||||
prob=(0.01, 0.1),
|
||||
prob=(0.025, 0.25),
|
||||
grad_scale=0.01)
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user