Mirror of https://github.com/k2-fsa/icefall.git
Make bypass_scale be a tensor.

commit a680c7de2e
parent ff6431ed0f
```diff
@@ -367,8 +367,8 @@ class ZipformerEncoderLayer(nn.Module):
         # to work correctly.
         layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.5), (2000.0, 0.05), default=0),
         dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.2), (2000.0, 0.0), default=0),
-        bypass_clamp_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25), default=0),
-        bypass_clamp_max: FloatLike = 1.0,
+        bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25), default=0),
+        bypass_max: FloatLike = 1.0,
     ) -> None:
         super(ZipformerEncoderLayer, self).__init__()
         self.embed_dim = embed_dim
```
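For context on the default above: `ScheduledFloat((0.0, 0.75), (20000.0, 0.25), default=0)` is icefall's schedule type whose float value is interpolated piecewise-linearly over the global batch count, so `bypass_min` starts at 0.75 and decays to 0.25 by batch 20000. A minimal sketch of that behavior, with the simplifying assumption that the batch index is passed in explicitly (the real class tracks a global counter):

```python
# Minimal sketch of a ScheduledFloat-style piecewise-linear schedule.
# Simplified assumption: the batch index is an explicit argument here,
# whereas icefall's real class reads a global batch counter.
class SimpleScheduledFloat:
    def __init__(self, *points, default=0.0):
        self.points = sorted(points)  # (batch_count, value) pairs
        self.default = default        # fallback when no batch count is known

    def value_at(self, batch: float) -> float:
        (x0, y0) = self.points[0]
        if batch <= x0:
            return y0
        for (x1, y1) in self.points[1:]:
            if batch <= x1:
                # linear interpolation between the two surrounding points
                return y0 + (y1 - y0) * (batch - x0) / (x1 - x0)
            x0, y0 = x1, y1
        return y0  # past the last point: hold the final value


sched = SimpleScheduledFloat((0.0, 0.75), (20000.0, 0.25))
assert sched.value_at(0) == 0.75
assert abs(sched.value_at(10000.0) - 0.5) < 1e-9  # halfway through the decay
assert sched.value_at(50000.0) == 0.25
```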
```diff
@@ -379,8 +379,8 @@ class ZipformerEncoderLayer(nn.Module):
         self.dynamic_skip_prob = copy.deepcopy(dynamic_skip_prob)
         # min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
         # ever becoming zero.
-        self.bypass_clamp_min = copy.deepcopy(bypass_clamp_min)
-        self.bypass_clamp_max = copy.deepcopy(bypass_clamp_max)
+        self.bypass_min = copy.deepcopy(bypass_min)
+        self.bypass_max = copy.deepcopy(bypass_max)


         self.self_attn_weights = RelPositionMultiheadAttentionWeights(
```
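The comment in this hunk explains why the limits are applied only with probability 0.5: a hard clamp passes zero gradient to a parameter sitting at or beyond the boundary, so clamping on every step could freeze the parameter there. A hedged sketch of what a helper matching that description could look like (the actual `limit_param_value` in icefall may be implemented differently, e.g. via a custom autograd function):

```python
import random
import torch

def limit_param_value_sketch(x: torch.Tensor, min: float, max: float,
                             prob: float = 0.5) -> torch.Tensor:
    # Sketch only: clamp on a random half of the training steps.  On the
    # other half the raw parameter is returned, so a value pinned at a
    # limit still receives a nonzero gradient and can be pulled back
    # into range by the training loss.
    if random.random() < prob:
        return x.clamp(min=min, max=max)
    return x
```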
```diff
@@ -421,7 +421,7 @@ class ZipformerEncoderLayer(nn.Module):

         self.norm_final = BasicNorm(embed_dim)

-        self.bypass_scale = nn.Parameter(torch.tensor(0.5))
+        self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))

         # try to ensure the output is close to zero-mean (or at least, zero-median).
         self.balancer = ActivationBalancer(
```
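This one-line hunk is the change the commit title describes: `bypass_scale` goes from a single scalar to a tensor with one learnable value per channel, so each of the `embed_dim` feature dimensions can learn its own mix of layer input and layer output. A sketch of how such a per-channel scale broadcasts when applied (the interpolation formula and variable names are illustrative, not taken from this diff):

```python
import torch

embed_dim = 4
bypass_scale = torch.full((embed_dim,), 0.5)  # one learnable scale per channel

src_orig = torch.randn(10, 2, embed_dim)  # layer input, (seq, batch, channel)
src = torch.randn(10, 2, embed_dim)       # layer output before the bypass

# The (embed_dim,) scale broadcasts over the trailing channel axis:
# scale 1.0 keeps the layer's full output, scale 0.0 reduces to the identity.
out = src_orig + (src - src_orig) * bypass_scale
assert out.shape == (10, 2, embed_dim)
```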
```diff
@@ -439,8 +439,8 @@ class ZipformerEncoderLayer(nn.Module):
             return self.bypass_scale
         else:
             return limit_param_value(self.bypass_scale,
-                                     min=float(self.bypass_clamp_min),
-                                     max=float(self.bypass_clamp_max))
+                                     min=float(self.bypass_min),
+                                     max=float(self.bypass_max))

     def forward(
         self,
```
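This last hunk sits inside an accessor that returns either the raw parameter or the range-limited one; the branch condition falls outside the hunk, but a natural reading is that the clamp is only needed during training. Note also that `float(self.bypass_min)` evaluates the `ScheduledFloat` at the current batch, so the permitted range itself relaxes from [0.75, 1.0] toward [0.25, 1.0] over the first 20000 batches. A small end-to-end sketch under those assumptions (`get_bypass_scale`, the `not self.training` condition, and the plain `clamp` are stand-ins):

```python
import torch
import torch.nn as nn

class BypassDemo(nn.Module):
    # Sketch of the accessor's two branches; the real branch condition is
    # outside the hunk, and clamp() stands in for limit_param_value().
    def __init__(self, embed_dim: int = 4):
        super().__init__()
        self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
        self.bypass_min = 0.75  # stand-in for float(ScheduledFloat(...)) early in training
        self.bypass_max = 1.0

    def get_bypass_scale(self) -> torch.Tensor:
        if not self.training:
            return self.bypass_scale
        else:
            return self.bypass_scale.clamp(min=float(self.bypass_min),
                                           max=float(self.bypass_max))

m = BypassDemo()
m.train()
print(m.get_bypass_scale())  # all 0.75: clamped up to the scheduled minimum
m.eval()
print(m.get_bypass_scale())  # raw parameter values, all 0.5
```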