diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 33faeb4a3..4ce8f5e30 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -255,8 +255,7 @@ class Zipformer2(EncoderInterface):
                     logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will "
                                  f"combine the outputs of layers {j} and {i-1}, with downsampling_factor={z[j]} and {z[i-1]}.")
                     skip_layers.append(j)
-                    skip_modules.append(SimpleCombiner(self.encoder_dim[i-1],
-                                                       min_weight=(0.0, 0.25)))
+                    skip_modules.append(BypassModule(self.encoder_dim[i]))
                     break
         self.skip_layers = skip_layers
         self.skip_modules = nn.ModuleList(skip_modules)
@@ -389,22 +388,25 @@ class Zipformer2(EncoderInterface):
         for i, module in enumerate(self.encoders):
             ds = self.downsampling_factor[i]

+            x = convert_num_channels(x, self.encoder_dim[i])
+
             if self.skip_layers[i] is not None:
                 # this how we implement U-net-like skipping of some series of
                 # stacks.  The layer_skip_dropout_prob is to discourage it from
                 # completely ignoring the middle layers, especially early in
                 # training,
-                batch_size = x.shape[1]
-                skip_x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
+                skip_output = convert_num_channels(outputs[self.skip_layers[i]],
+                                                   self.encoder_dim[i])
+                skip_x = self.skip_modules[i](skip_output, x)
                 layer_skip_dropout_prob = float(self.layer_skip_dropout_prob)
                 if self.training and layer_skip_dropout_prob > 0:
+                    batch_size = x.shape[1]
                     mask = (torch.rand((1, batch_size, 1), device=x.device) >
                             layer_skip_dropout_prob)
                     x = torch.where(mask, skip_x, x)
                 else:
                     x = skip_x

-            x = convert_num_channels(x, self.encoder_dim[i])
             x = module(x,
                        chunk_size=chunk_size,
                        feature_mask=feature_masks[i],
@@ -524,14 +526,15 @@ class Zipformer2EncoderLayer(nn.Module):
             const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
             ff2_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
             ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
+            bypass_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0),
     ) -> None:
         super(Zipformer2EncoderLayer, self).__init__()

         self.embed_dim = embed_dim

         # self.bypass implements layer skipping as well as bypass; see its default values.
-        self.bypass = BypassModule(embed_dim)
+        self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate)
         # bypass_mid is bypass used in the middle of the layer.
-        self.bypass_mid = BypassModule(embed_dim, skip_rate=0.0)
+        self.bypass_mid = BypassModule(embed_dim)

         # skip probability for dynamic modules (meaning: anything but feedforward).
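Note on the reordering in the forward-pass hunk above: x is now converted to self.encoder_dim[i] channels before the skip combination, and the saved output of the earlier stack is converted likewise, so the BypassModule receives two tensors of identical shape (the deleted SimpleCombiner used to do this channel conversion internally). convert_num_channels itself is unchanged and does not appear in this diff; the sketch below is only an illustration of the truncate-or-zero-pad behavior the calls above assume, not a copy of the real helper.

import torch
from torch import Tensor

def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
    # Assumed behavior: make the last dim equal num_channels by
    # truncating surplus channels or zero-padding missing ones;
    # a no-op when the dims already agree.
    if num_channels <= x.shape[-1]:
        return x[..., :num_channels]
    pad_shape = list(x.shape)
    pad_shape[-1] = num_channels - x.shape[-1]
    zeros = torch.zeros(pad_shape, dtype=x.dtype, device=x.device)
    return torch.cat((x, zeros), dim=-1)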
@@ -872,7 +875,7 @@ class BypassModule(nn.Module):
     def __init__(
         self,
         embed_dim: int,
-        skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0),
+        skip_rate: FloatLike = 0.0,
         scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
         scale_max: FloatLike = 1.0):
         super().__init__()
@@ -931,7 +934,7 @@ class DownsampledZipformer2Encoder(nn.Module):
                                            downsample, dropout)
         self.encoder = encoder
         self.upsample = SimpleUpsample(dim, downsample)
-        self.out_combiner = SimpleCombiner(dim, min_weight=(0.0, 0.25))
+        self.out_combiner = BypassModule(dim)

     def forward(self,
@@ -1048,46 +1051,6 @@ class SimpleUpsample(torch.nn.Module):
         src = src.reshape(seq_len * upsample, batch_size, num_channels)
         return src

-
-class SimpleCombiner(torch.nn.Module):
-    """
-    A very simple way of combining 2 vectors of 2 different dims, via a
-    learned weighted combination in the shared part of the dim.
-    Args:
-      dim1: the dimension of the first input, e.g. 256
-      dim2: the dimension of the second input, e.g. 384.
-            The output will have the same dimension as dim2.
-    """
-    def __init__(self,
-                 dim: int,
-                 min_weight: Tuple[float, float] = (0., 0.)):
-        super(SimpleCombiner, self).__init__()
-        initial_weight1 = 0.1
-        self.weight1 = nn.Parameter(torch.full((dim,), initial_weight1))
-        self.min_weight = min_weight
-
-    def forward(self,
-                src1: Tensor,
-                src2: Tensor) -> Tensor:
-        """
-        src1: (*, other_dim)
-        src2: (*, dim)
-
-        Returns: a tensor of shape (*, dim)
-        """
-        assert src1.shape[:-1] == src2.shape[:-1]
-        num_channels = src2.shape[-1]
-        src1 = convert_num_channels(src1, num_channels)
-
-
-        weight1 = limit_param_value(self.weight1,
-                                    min=self.min_weight[0],
-                                    max=1.0-self.min_weight[1],
-                                    training=self.training)
-
-        return src1 * weight1 + src2 * (1.0 - weight1)
-
-
 class CompactRelPositionalEncoding(torch.nn.Module):
     """
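For reference, the deleted SimpleCombiner returned the convex combination src1 * weight1 + src2 * (1 - weight1), with a learned per-channel weight1 clamped into [min_weight[0], 1 - min_weight[1]]. Only the __init__ signature of its replacement, BypassModule, is visible in the hunks above; the sketch below shows the per-channel interpolation it is assumed to perform. This is a hypothetical simplification: plain floats stand in for the scheduled FloatLike values, and limit_param_value is replaced by clamp.

import torch
from torch import Tensor, nn

class BypassModuleSketch(nn.Module):
    # Simplified stand-in for BypassModule, for illustration only.
    def __init__(self, embed_dim: int, skip_rate: float = 0.0,
                 scale_min: float = 0.2, scale_max: float = 1.0):
        super().__init__()
        self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
        self.skip_rate = skip_rate
        self.scale_min = scale_min
        self.scale_max = scale_max

    def _get_scale(self, batch_size: int) -> Tensor:
        scale = self.bypass_scale.clamp(self.scale_min, self.scale_max)
        if self.training and self.skip_rate > 0.0:
            # With probability skip_rate, zero the scale for an entire
            # sequence, so that sequence keeps src_orig unchanged
            # (i.e. the module output is skipped).
            keep = (torch.rand(batch_size, 1, device=scale.device)
                    > self.skip_rate)
            scale = scale * keep  # (batch_size, embed_dim)
        return scale

    def forward(self, src_orig: Tensor, src: Tensor) -> Tensor:
        # Inputs are (seq_len, batch_size, embed_dim); interpolate per
        # channel between the bypass input and the processed output.
        return src_orig + (src - src_orig) * self._get_scale(src.shape[1])

In the forward-pass hunk, this is called as self.skip_modules[i](skip_output, x), so a scale near 0 keeps the earlier stack's output and a scale near 1 keeps the current x. With the new default skip_rate=0.0, random skipping happens only where a schedule is passed explicitly, as with bypass_skip_rate in Zipformer2EncoderLayer.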