diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 9985c9001..4019b6358 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -178,7 +178,8 @@ class Zipformer(EncoderInterface):
                                  f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}.")
                     skip_layers.append(j)
                     skip_modules.append(SimpleCombiner(self.encoder_dims[j],
-                                                       self.encoder_dims[i-1]))
+                                                       self.encoder_dims[i-1],
+                                                       min_weight=(0.0, 0.25)))
                     break
         self.skip_layers = skip_layers
         self.skip_modules = nn.ModuleList(skip_modules)
@@ -620,7 +621,8 @@ class DownsampledZipformerEncoder(nn.Module):
         self.encoder = encoder
         self.upsample = SimpleUpsample(output_dim, downsample)
         self.out_combiner = SimpleCombiner(input_dim,
-                                           output_dim)
+                                           output_dim,
+                                           min_weight=(0.0, 0.25))
 
 
     def forward(self,
@@ -814,17 +816,18 @@ class SimpleCombiner(torch.nn.Module):
     learned weighted combination in the shared part of the dim.
     Args:
        dim1: the dimension of the first input, e.g. 256
-       dim2: the dimension of the second input, e.g. 384.  Require dim2 >= dim1.
+       dim2: the dimension of the second input, e.g. 384.  Require dim2 >= dim1; the output will have dimension dim2.
     """
 
     def __init__(self,
                  dim1: int,
-                 dim2: int):
+                 dim2: int,
+                 min_weight: Tuple[float, float] = (0., 0.)):
         super(SimpleCombiner, self).__init__()
         assert dim2 >= dim1
         self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01)
         self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01)
-
+        self.min_weight = min_weight
 
 
     def forward(self,
@@ -844,15 +847,36 @@ class SimpleCombiner(torch.nn.Module):
         logit = (weight1 + weight2)
 
         if self.training and random.random() < 0.1:
-            logit = penalize_abs_values_gt(logit,
-                                           limit=25.0,
-                                           penalty=1.0e-04)
+            logit = penalize_abs_values_gt(logit,
+                                           limit=25.0,
+                                           penalty=1.0e-04)
+
+        # `weight` will be the weight on src1.
         weight = logit.sigmoid()
-        src2_part1 = src2[...,:dim1]
-        part1 = src1 * weight + src2_part1 * (1.0 - weight)
-        part2 = src2[...,dim1:]
-        return torch.cat((part1, part2), dim=-1)
+        if self.training and self.min_weight != (0., 0.) and random.random() < 0.25:
+            weight = weight.clamp(min=self.min_weight[0],
+                                  max=1.0-self.min_weight[1])
+
+        src1 = src1 * weight
+        src2 = src2 * (1.0 - weight)
+
+        src1_dim = src1.shape[-1]
+        src2_dim = src2.shape[-1]
+        if src1_dim != src2_dim:
+            if src1_dim < src2_dim:
+                zeros_shape = list(src1.shape[:-1]) + [src2_dim - src1_dim]
+                src1 = torch.cat((src1, torch.zeros(*zeros_shape,
+                                                    device=src1.device,
+                                                    dtype=src1.dtype)),
+                                 dim=-1)
+            else:
+                src1 = src1[..., :src2_dim]
+
+        return src1 + src2
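
For context, below is a minimal standalone sketch of the combination rule this patch introduces. It is illustrative only: the free function `combine` and the externally supplied per-frame `weight` are assumptions made for the sketch; in the real `SimpleCombiner`, the weight is derived from the learned `to_weight1`/`to_weight2` parameters via a sigmoid.

import random
from typing import Tuple

import torch


def combine(src1: torch.Tensor,
            src2: torch.Tensor,
            weight: torch.Tensor,
            min_weight: Tuple[float, float] = (0.0, 0.25),
            training: bool = True) -> torch.Tensor:
    """Sketch of the patched combination rule; `weight` is the weight on src1."""
    # During training, on a random 25% of calls, clamp the weight on src1
    # into [min_weight[0], 1 - min_weight[1]] so the combiner cannot put
    # all of its mass on a single input.
    if training and min_weight != (0.0, 0.0) and random.random() < 0.25:
        weight = weight.clamp(min=min_weight[0], max=1.0 - min_weight[1])

    src1 = src1 * weight
    src2 = src2 * (1.0 - weight)

    # Zero-pad src1 on the channel dim up to src2's width (the module
    # asserts dim2 >= dim1), then sum instead of concatenating parts.
    d1, d2 = src1.shape[-1], src2.shape[-1]
    if d1 < d2:
        pad = torch.zeros(*src1.shape[:-1], d2 - d1,
                          device=src1.device, dtype=src1.dtype)
        src1 = torch.cat((src1, pad), dim=-1)

    return src1 + src2


# e.g. (seq_len, batch, 256) combined with (seq_len, batch, 384) -> (seq_len, batch, 384)
out = combine(torch.randn(10, 2, 256),
              torch.randn(10, 2, 384),
              weight=torch.rand(10, 2, 1))
print(out.shape)  # torch.Size([10, 2, 384])

With min_weight=(0.0, 0.25) as used at both call sites, the occasional clamp keeps the weight on src1 within [0.0, 0.75]; the zero-padding path replaces the old behavior of mixing only the first dim1 channels and concatenating the remainder with a plain weighted sum in the wider dimension.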