From 97a0fbe44b984a876a84ee1cf3557a1543c093b4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 8 Oct 2022 20:32:49 +0800 Subject: [PATCH] Make the bypass scale trainable. --- .../ASR/pruned_transducer_stateless7/conformer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 210f89cc6..222946f2b 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -294,6 +294,8 @@ class ConformerEncoderLayer(nn.Module): self.norm_final = BasicNorm(d_model) + self.bypass_scale = nn.Parameter(torch.tensor(0.5)) + # try to ensure the output is close to zero-mean (or at least, zero-median). self.balancer = ActivationBalancer( d_model, channel_dim=-1, @@ -367,13 +369,13 @@ class ConformerEncoderLayer(nn.Module): src = self.norm_final(self.balancer(src)) warmup_value = self.get_warmup_value(warmup_count) + + delta = src - src_orig if warmup_value < 1.0 and self.training: - delta = src - src_orig keep_prob = 0.25 + 0.75 * warmup_value # the :1 means the mask is chosen per frame. delta = delta * (torch.rand_like(delta[...,:1]) < keep_prob) - src = src_orig + delta - + src = src_orig + delta * self.bypass_scale return src, attn_scores_out @@ -514,8 +516,6 @@ class ConformerEncoder(nn.Module): src_key_padding_mask=src_key_padding_mask, warmup_count=warmup_count, ) - # this seemed to be helpful... - output = 0.5 * (next_output + output) output = output * feature_mask