From a680c7de2e5543565f6adfd51d4670a95e9fa2c5 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 14 Nov 2022 19:12:16 +0800
Subject: [PATCH] Make bypass_scale be a tensor.

---
 .../ASR/pruned_transducer_stateless7/zipformer.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 232c0d4c8..e686ad0b4 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -367,8 +367,8 @@ class ZipformerEncoderLayer(nn.Module):
         # to work correctly.
         layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.5), (2000.0, 0.05), default=0),
         dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.2), (2000.0, 0.0), default=0),
-        bypass_clamp_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25), default=0),
-        bypass_clamp_max: FloatLike = 1.0,
+        bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25), default=0),
+        bypass_max: FloatLike = 1.0,
     ) -> None:
         super(ZipformerEncoderLayer, self).__init__()
         self.embed_dim = embed_dim
@@ -379,8 +379,8 @@ class ZipformerEncoderLayer(nn.Module):
         self.dynamic_skip_prob = copy.deepcopy(dynamic_skip_prob)
         # min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
         # ever becoming zero.
-        self.bypass_clamp_min = copy.deepcopy(bypass_clamp_min)
-        self.bypass_clamp_max = copy.deepcopy(bypass_clamp_max)
+        self.bypass_min = copy.deepcopy(bypass_min)
+        self.bypass_max = copy.deepcopy(bypass_max)


         self.self_attn_weights = RelPositionMultiheadAttentionWeights(
@@ -421,7 +421,7 @@ class ZipformerEncoderLayer(nn.Module):

         self.norm_final = BasicNorm(embed_dim)

-        self.bypass_scale = nn.Parameter(torch.tensor(0.5))
+        self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))

         # try to ensure the output is close to zero-mean (or at least, zero-median).
         self.balancer = ActivationBalancer(
@@ -439,8 +439,8 @@ class ZipformerEncoderLayer(nn.Module):
             return self.bypass_scale
         else:
             return limit_param_value(self.bypass_scale,
-                                     min=float(self.bypass_clamp_min),
-                                     max=float(self.bypass_clamp_max))
+                                     min=float(self.bypass_min),
+                                     max=float(self.bypass_max))

     def forward(
         self,
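
Note: below is a minimal, self-contained sketch of the idea behind this change (a per-channel bypass/residual scale of shape (embed_dim,) instead of a single scalar), not the actual icefall implementation. The class name BypassSketch, the plain clamp standing in for icefall's limit_param_value helper, and the mixing formula in forward are illustrative assumptions; only the nn.Parameter(torch.full((embed_dim,), 0.5)) initialization and the bypass_min/bypass_max names come from the patch itself.

# ---- illustrative sketch (not part of the patch) ----
import torch
import torch.nn as nn


class BypassSketch(nn.Module):
    """Per-channel bypass scale: each of the embed_dim channels learns its own
    interpolation weight between the layer input and the layer output."""

    def __init__(self, embed_dim: int, bypass_min: float = 0.25, bypass_max: float = 1.0):
        super().__init__()
        # Per-channel scale, initialized to 0.5 as in the patch.
        self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
        self.bypass_min = bypass_min
        self.bypass_max = bypass_max

    def get_bypass_scale(self) -> torch.Tensor:
        if not self.training:
            return self.bypass_scale
        # Rough stand-in for limit_param_value: keep the learned scale inside
        # [bypass_min, bypass_max] during training (the real helper also applies
        # the limit only with some probability, so gradients do not stay zero
        # when the parameter sits at a limit).
        return self.bypass_scale.clamp(min=self.bypass_min, max=self.bypass_max)

    def forward(self, src_orig: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
        # src_orig, src: (seq_len, batch_size, embed_dim).  The (embed_dim,)-shaped
        # scale broadcasts over the leading dimensions, so each channel gets its
        # own mix of the residual input and the layer output.
        scale = self.get_bypass_scale()
        return src_orig + (src - src_orig) * scale


# Usage: interpolate per channel between the residual input and the layer output.
layer = BypassSketch(embed_dim=384)
src_orig = torch.randn(100, 8, 384)   # layer input (residual branch)
src = torch.randn(100, 8, 384)        # layer output
out = layer(src_orig, src)            # (100, 8, 384)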