mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)

Make bypass_scale be a tensor.

commit a680c7de2e
parent ff6431ed0f
@@ -367,8 +367,8 @@ class ZipformerEncoderLayer(nn.Module):
             # to work correctly.
             layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.5), (2000.0, 0.05), default=0),
             dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.2), (2000.0, 0.0), default=0),
-            bypass_clamp_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25), default=0),
-            bypass_clamp_max: FloatLike = 1.0,
+            bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25), default=0),
+            bypass_max: FloatLike = 1.0,
     ) -> None:
         super(ZipformerEncoderLayer, self).__init__()
         self.embed_dim = embed_dim
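The renamed bypass_min argument keeps its ScheduledFloat default, which icefall evaluates as a function of the training batch count. A minimal sketch of the piecewise-linear idea behind such a two-breakpoint schedule (the helper below is a hypothetical illustration, not icefall's ScheduledFloat implementation):

def piecewise_linear(batch_count: float, points) -> float:
    # Two-breakpoint schedule, e.g. [(0.0, 0.75), (20000.0, 0.25)]:
    # the value is 0.75 at batch 0, decays linearly to 0.25 by batch 20000,
    # and stays constant outside that range.
    (x0, y0), (x1, y1) = points[0], points[-1]
    if batch_count <= x0:
        return y0
    if batch_count >= x1:
        return y1
    t = (batch_count - x0) / (x1 - x0)
    return y0 + t * (y1 - y0)

assert piecewise_linear(0.0, [(0.0, 0.75), (20000.0, 0.25)]) == 0.75
assert piecewise_linear(10000.0, [(0.0, 0.75), (20000.0, 0.25)]) == 0.5
assert piecewise_linear(20000.0, [(0.0, 0.75), (20000.0, 0.25)]) == 0.25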
@@ -379,8 +379,8 @@ class ZipformerEncoderLayer(nn.Module):
         self.dynamic_skip_prob = copy.deepcopy(dynamic_skip_prob)
         # min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
         # ever becoming zero.
-        self.bypass_clamp_min = copy.deepcopy(bypass_clamp_min)
-        self.bypass_clamp_max = copy.deepcopy(bypass_clamp_max)
+        self.bypass_min = copy.deepcopy(bypass_min)
+        self.bypass_max = copy.deepcopy(bypass_max)


         self.self_attn_weights = RelPositionMultiheadAttentionWeights(
@@ -421,7 +421,7 @@ class ZipformerEncoderLayer(nn.Module):

         self.norm_final = BasicNorm(embed_dim)

-        self.bypass_scale = nn.Parameter(torch.tensor(0.5))
+        self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))

         # try to ensure the output is close to zero-mean (or at least, zero-median).
         self.balancer = ActivationBalancer(
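This hunk is the core of the change: bypass_scale goes from a scalar parameter to a tensor of shape (embed_dim,), so each feature channel can learn its own bypass weight. A hedged sketch of how such a per-channel scale broadcasts when mixing the layer input with the layer output (the exact combination used by ZipformerEncoderLayer is not part of this diff, so the formula below is an illustration only):

import torch

embed_dim = 4
bypass_scale = torch.nn.Parameter(torch.full((embed_dim,), 0.5))

src = torch.randn(10, 2, embed_dim)        # layer input: (seq_len, batch, embed_dim)
layer_out = torch.randn(10, 2, embed_dim)  # output of the layer's sub-modules

# bypass_scale broadcasts over the trailing embed_dim dimension, so every
# channel interpolates between input and output with its own learned weight.
combined = src + bypass_scale * (layer_out - src)
print(combined.shape)  # torch.Size([10, 2, 4])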
@@ -439,8 +439,8 @@ class ZipformerEncoderLayer(nn.Module):
             return self.bypass_scale
         else:
             return limit_param_value(self.bypass_scale,
-                                     min=float(self.bypass_clamp_min),
-                                     max=float(self.bypass_clamp_max))
+                                     min=float(self.bypass_min),
+                                     max=float(self.bypass_max))

     def forward(
         self,
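Here float(self.bypass_min) evaluates the schedule at the current batch count, and limit_param_value keeps bypass_scale inside [bypass_min, bypass_max]. Per the comment earlier in the diff, the limit is applied with probability 0.5 so the clamp does not zero the gradient of out-of-range entries on every step. A hypothetical stand-in (not icefall's limit_param_value) illustrating that idea:

import random
import torch

def clamp_param_value(x: torch.Tensor, min: float, max: float,
                      prob: float = 0.5) -> torch.Tensor:
    # With probability `prob`, clamp the value into [min, max]; otherwise
    # return it unchanged, so gradients of out-of-range entries are not
    # suppressed on every forward pass.
    if random.random() < prob:
        return x.clamp(min=min, max=max)
    return x

scale = torch.nn.Parameter(torch.full((4,), 0.9))
print(clamp_param_value(scale, min=0.25, max=0.75))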