From b091ae54826a87cb774d563e74a7c9a62ea0e2aa Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 31 Oct 2022 17:10:28 +0800 Subject: [PATCH] Add bias in weight module --- egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 4019b6358..e0d48b0af 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -827,6 +827,7 @@ class SimpleCombiner(torch.nn.Module): assert dim2 >= dim1 self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01) self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01) + self.bias = nn.Parameter(torch.zeros(())) self.min_weight = min_weight def forward(self, @@ -844,7 +845,7 @@ class SimpleCombiner(torch.nn.Module): weight1 = (src1 * self.to_weight1).sum(dim=-1, keepdim=True) weight2 = (src2 * self.to_weight2).sum(dim=-1, keepdim=True) - logit = (weight1 + weight2) + logit = (weight1 + weight2 + self.bias) if self.training and random.random() < 0.1: logit = penalize_abs_values_gt(logit,