Make bypass_scale a tensor.

Daniel Povey 2022-11-14 19:12:16 +08:00
parent ff6431ed0f
commit a680c7de2e


@@ -367,8 +367,8 @@ class ZipformerEncoderLayer(nn.Module):
         # to work correctly.
         layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.5), (2000.0, 0.05), default=0),
         dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.2), (2000.0, 0.0), default=0),
-        bypass_clamp_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25), default=0),
-        bypass_clamp_max: FloatLike = 1.0,
+        bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25), default=0),
+        bypass_max: FloatLike = 1.0,
     ) -> None:
         super(ZipformerEncoderLayer, self).__init__()
         self.embed_dim = embed_dim
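
[Editor's note] A ScheduledFloat behaves as a float whose value is piecewise-linearly interpolated over the training batch count, so the bypass_min default above decays from 0.75 at batch 0 to 0.25 by batch 20000. The real class lives in icefall's scaling.py and carries more machinery; the function below is only a hypothetical sketch of the interpolation semantics, with the batch-count plumbing assumed:

    # Hypothetical sketch of ScheduledFloat's schedule semantics, NOT the real class.
    # Each (x, y) pair maps a batch count x to a value y; values are linearly
    # interpolated between pairs and clamped at the ends of the schedule.
    def scheduled_value(batch_count: float, *points) -> float:
        points = sorted(points)
        if batch_count <= points[0][0]:
            return points[0][1]
        if batch_count >= points[-1][0]:
            return points[-1][1]
        for (x0, y0), (x1, y1) in zip(points, points[1:]):
            if x0 <= batch_count <= x1:
                return y0 + (y1 - y0) * (batch_count - x0) / (x1 - x0)

    # e.g. the bypass_min schedule above: 0.75 at batch 0, 0.25 from batch 20000 on
    assert scheduled_value(0.0, (0.0, 0.75), (20000.0, 0.25)) == 0.75
    assert scheduled_value(10000.0, (0.0, 0.75), (20000.0, 0.25)) == 0.5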
@@ -379,8 +379,8 @@ class ZipformerEncoderLayer(nn.Module):
         self.dynamic_skip_prob = copy.deepcopy(dynamic_skip_prob)
         # min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
         # ever becoming zero.
-        self.bypass_clamp_min = copy.deepcopy(bypass_clamp_min)
-        self.bypass_clamp_max = copy.deepcopy(bypass_clamp_max)
+        self.bypass_min = copy.deepcopy(bypass_min)
+        self.bypass_max = copy.deepcopy(bypass_max)

         self.self_attn_weights = RelPositionMultiheadAttentionWeights(
@@ -421,7 +421,7 @@ class ZipformerEncoderLayer(nn.Module):
         self.norm_final = BasicNorm(embed_dim)

-        self.bypass_scale = nn.Parameter(torch.tensor(0.5))
+        self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))

         # try to ensure the output is close to zero-mean (or at least, zero-median).
         self.balancer = ActivationBalancer(
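
[Editor's note] This is the core of the change: the bypass goes from one learned scalar shared by all channels to one learned scale per channel, each initialized to 0.5. The exact combination in forward() is not shown in this diff, so the blending formula below is an assumption; it only illustrates how a (embed_dim,)-shaped scale broadcasts over the channel dimension:

    import torch

    embed_dim = 4
    bypass_scale = torch.full((embed_dim,), 0.5)    # per-channel, as in the new code

    src_orig = torch.randn(10, 2, embed_dim)        # layer input  (seq, batch, channel)
    src = src_orig + torch.randn(10, 2, embed_dim)  # layer output before the bypass

    # Assumed blending: a scale of 1.0 keeps the layer's output, 0.0 reduces the
    # layer to an identity; broadcasting over the last dim gives each channel
    # its own interpolation weight.
    out = src_orig + bypass_scale * (src - src_orig)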
@@ -439,8 +439,8 @@ class ZipformerEncoderLayer(nn.Module):
             return self.bypass_scale
         else:
             return limit_param_value(self.bypass_scale,
-                                     min=float(self.bypass_clamp_min),
-                                     max=float(self.bypass_clamp_max))
+                                     min=float(self.bypass_min),
+                                     max=float(self.bypass_max))

     def forward(
         self,
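
[Editor's note] The body of limit_param_value is not part of this diff. Below is a minimal sketch consistent with the comment above ("applied with probability 0.5 to avoid grads ever becoming zero"), assuming the signature implied by the call site; the actual icefall implementation may differ:

    import random
    import torch
    from torch import Tensor

    def limit_param_value(x: Tensor, min: float, max: float, prob: float = 0.5) -> Tensor:
        # Assumed behavior: clamp the parameter into [min, max] only on a random
        # subset of training steps.  When clamp() is active, the gradient w.r.t.
        # any element outside the range is zero; skipping the clamp with
        # probability (1 - prob) lets such elements still receive gradient
        # some of the time, so they are never permanently stuck.
        if random.random() > prob:
            return x
        return x.clamp(min=min, max=max)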