diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index e0d48b0af..0a761d878 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -825,9 +825,7 @@ class SimpleCombiner(torch.nn.Module): min_weight: Tuple[float] = (0., 0.)): super(SimpleCombiner, self).__init__() assert dim2 >= dim1 - self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01) - self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01) - self.bias = nn.Parameter(torch.zeros(())) + self.weight1 = nn.Parameter(torch.zeros(())) self.min_weight = min_weight def forward(self, @@ -843,25 +841,15 @@ class SimpleCombiner(torch.nn.Module): dim1 = src1.shape[-1] dim2 = src2.shape[-1] - weight1 = (src1 * self.to_weight1).sum(dim=-1, keepdim=True) - weight2 = (src2 * self.to_weight2).sum(dim=-1, keepdim=True) - logit = (weight1 + weight2 + self.bias) - if self.training and random.random() < 0.1: - logit = penalize_abs_values_gt(logit, - limit=25.0, - penalty=1.0e-04) - - # `weight` will be the wight on src1. - weight = logit.sigmoid() - - if self.training and self.min_weight != (0., 0.) and random.random() < 0.25: - weight = weight.clamp(min=self.min_weight[0], - max=1.0-self.min_weight[1]) + weight1 = self.weight1 + if self.training and random.random() < 0.25 and self.min_weight != (0., 0.): + weight1 = weight1.clamp(min=self.min_weight[0], + max=1.0-self.min_weight[1]) - src1 = src1 * weight - src2 = src2 * (1.0 - weight) + src1 = src1 * weight1 + src2 = src2 * (1.0 - weight1) src1_dim = src1.shape[-1] src2_dim = src2.shape[-1]