diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 8e4733bfc..e8df40b8e 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -138,7 +138,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--zipformer-downsampling-factors",
         type=str,
-        default="1,2,4",
+        default="2,4,8",
         help="Downsampling factor for each stack of encoder layers.",
     )
 
@@ -428,7 +428,7 @@ def get_params() -> AttributeDict:
            "valid_interval": 3000,  # For the 100h subset, use 800
            # parameters for zipformer
            "feature_dim": 80,
-           "subsampling_factor": 4,
+           "subsampling_factor": 4,  # not passed in; the overall factor is fixed at 4
            "warm_step": 2000,
            "env_info": get_env_info(),
        }
@@ -443,7 +443,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         return tuple(map(int, s.split(',')))
     encoder = Zipformer(
         num_features=params.feature_dim,
-        subsampling_factor=params.subsampling_factor,
+        output_downsampling_factor=2,
         zipformer_downsampling_factors=to_int_tuple(params.zipformer_downsampling_factors),
         encoder_dims=to_int_tuple(params.encoder_dims),
         attention_dim=to_int_tuple(params.attention_dims),
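A minimal sketch of the length bookkeeping these defaults imply; `encoder_out_len` is a hypothetical helper, not part of the patch. `Conv2dSubsampling` now halves the frame rate (T -> (T-7)//2), and the new output `AttentionDownsample` halves it again, which is why `subsampling_factor` stays fixed at 4 even though it is no longer passed to `Zipformer`. The `--zipformer-downsampling-factors` values, as the name `DownsampledZipformerEncoder` suggests, only downsample inside each stack, which returns to the embed frame rate at its output.

```python
def to_int_tuple(s: str):
    # same parsing as the to_int_tuple helper in get_encoder_model()
    return tuple(map(int, s.split(",")))


def encoder_out_len(x_lens: int) -> int:
    # hypothetical helper mirroring the patched Zipformer.forward()
    t = (x_lens - 7) // 2  # Conv2dSubsampling: T -> (T-7)//2
    return (t + 1) // 2    # output AttentionDownsample: t -> (t+1)//2


print(to_int_tuple("2,4,8"))  # (2, 4, 8): per-stack internal factors
print(encoder_out_len(1000))  # 248, i.e. roughly 1000 / 4
```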
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 01e2b0b1c..cfc15ab94 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -47,7 +47,7 @@ class Zipformer(EncoderInterface):
     """
     Args:
         num_features (int): Number of input features
-        subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
+        output_downsampling_factor (int): how much to downsample the final output of the encoder stacks
         d_model: (int,int): embedding dimension of 2 encoder stacks
         attention_dim: (int,int): attention dimension of 2 encoder stacks
         nhead (int, int): number of heads
@@ -62,12 +61,11 @@ class Zipformer(EncoderInterface):
     def __init__(
         self,
         num_features: int,
-        subsampling_factor: int = 4,
-        zipformer_subsampling_factor: int = 4,
+        output_downsampling_factor: int = 2,
         encoder_dims: Tuple[int] = (384, 384),
         attention_dim: Tuple[int] = (256, 256),
         encoder_unmasked_dims: Tuple[int] = (256, 256),
-        zipformer_downsampling_factors: Tuple[int] = (1, 2),
+        zipformer_downsampling_factors: Tuple[int] = (2, 4),
         nhead: Tuple[int] = (8, 8),
         feedforward_dim: Tuple[int] = (1536, 2048),
         num_encoder_layers: Tuple[int] = (12, 12),
@@ -78,23 +76,20 @@ class Zipformer(EncoderInterface):
         super(Zipformer, self).__init__()
 
         self.num_features = num_features
-        self.subsampling_factor = subsampling_factor
         self.encoder_unmasked_dims = encoder_unmasked_dims
         assert 0 < encoder_dims[0] <= encoder_dims[1]
         self.encoder_dims = encoder_dims
         self.encoder_unmasked_dims = encoder_unmasked_dims
         self.zipformer_downsampling_factors = zipformer_downsampling_factors
+        self.output_downsampling_factor = output_downsampling_factor
         for u,d in zip(encoder_unmasked_dims, encoder_dims):
             assert u <= d
 
-        if subsampling_factor != 4:
-            raise NotImplementedError("Support only 'subsampling_factor=4'.")
-
         # self.encoder_embed converts the input of shape (N, T, num_features)
-        # to the shape (N, T//subsampling_factor, encoder_dims).
+        # to the shape (N, T//2, encoder_dims).
         # That is, it does two things simultaneously:
-        #   (1) subsampling: T -> T//subsampling_factor
+        #   (1) subsampling: T -> T//2
         #   (2) embedding: num_features -> encoder_dims
         self.encoder_embed = Conv2dSubsampling(num_features, encoder_dims[0],
                                                dropout=dropout)
@@ -125,10 +120,9 @@ class Zipformer(EncoderInterface):
             )
 
             if zipformer_downsampling_factors[i] != 1:
-                assert i > 0, "First zipformer layer cannot use downsampling"
                 encoder = DownsampledZipformerEncoder(
                     encoder,
-                    input_dim=encoder_dims[i-1],
+                    input_dim=encoder_dims[i-1] if i > 0 else encoder_dims[0],
                     output_dim=encoder_dims[i],
                     downsample=zipformer_downsampling_factors[i],
                 )
@@ -136,6 +130,10 @@ class Zipformer(EncoderInterface):
 
         self.encoders = nn.ModuleList(encoders)
 
+        self.downsample_output = AttentionDownsample(encoder_dims[-1],
+                                                     encoder_dims[-1],
+                                                     downsample=output_downsampling_factor)
+
     def get_feature_masks(
             self,
             x: torch.Tensor) -> List[Union[float, Tensor]]:
@@ -216,8 +214,7 @@ class Zipformer(EncoderInterface):
 
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
-            # Caution: We assume the subsampling factor is 4!
-            lengths = ((x_lens - 1) // 2 - 1) // 2
+            lengths = (x_lens - 7) // 2
             assert x.size(0) == lengths.max().item()
 
         mask = make_pad_mask(lengths)
@@ -229,6 +226,10 @@ class Zipformer(EncoderInterface):
                       feature_mask=feature_masks[i],
                       src_key_padding_mask=None if mask is None else mask[...,::ds])
 
+        x = self.downsample_output(x)
+        # AttentionDownsample rounds lengths up: t -> (t + 1) // 2
+        lengths = (lengths + 1) // 2
+
         x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
 
@@ -1468,7 +1469,7 @@ class Conv2dSubsampling(nn.Module):
 
     Convert an input of shape (N, T, idim) to an output
     with shape (N, T', odim), where
-    T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
+    T' = (T-3)//2 - 2 == (T-7)//2
 
     It is based on
     https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
@@ -1489,7 +1490,7 @@ class Conv2dSubsampling(nn.Module):
           Number of channels in. The input shape is (N, T, in_channels).
          Caution: It requires: T >=7, in_channels >=7
        out_channels
-          Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels)
+          Output dim. The output shape is (N, (T-7)//2, out_channels)
        layer1_channels:
          Number of channels in layer1
        layer1_channels:
@@ -1503,7 +1504,7 @@ class Conv2dSubsampling(nn.Module):
                 in_channels=1,
                 out_channels=layer1_channels,
                 kernel_size=3,
-                padding=1,
+                padding=(0, 1),  # (time, freq)
             ),
             ActivationBalancer(layer1_channels, channel_dim=1),
             DoubleSwish(),
@@ -1513,6 +1514,7 @@ class Conv2dSubsampling(nn.Module):
                 out_channels=layer2_channels,
                 kernel_size=3,
                 stride=2,
+                padding=0,
             ),
             ActivationBalancer(layer2_channels, channel_dim=1),
             DoubleSwish(),
@@ -1521,13 +1523,13 @@ class Conv2dSubsampling(nn.Module):
                 in_channels=layer2_channels,
                 out_channels=layer3_channels,
                 kernel_size=3,
-                stride=2,
+                stride=(1, 2),  # (time, freq)
             ),
             ActivationBalancer(layer3_channels, channel_dim=1),
             DoubleSwish(),
         )
-        out_height = (((in_channels - 1) // 2 - 1) // 2)
+        out_height = (((in_channels - 1) // 2) - 1) // 2
         self.out = ScaledLinear(out_height * layer3_channels, out_channels)
         self.dropout = nn.Dropout(dropout)
 
@@ -1545,7 +1547,7 @@ class Conv2dSubsampling(nn.Module):
         # On entry, x is (N, T, idim)
         x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
         x = self.conv(x)
-        # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
+        # Now x is of shape (N, odim, (T-3)//2 - 2, ((idim-1)//2 - 1)//2)
         b, c, t, f = x.size()
         x = self.out(x.transpose(1, 2).reshape(b, t, c * f))
-        # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
+        # Now x is of shape (N, (T-7)//2, odim)
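A quick standalone check of the new shape arithmetic, using plain `nn.Conv2d` stand-ins for the patched conv stack; only the kernel/stride/padding values matter here, the channel widths (8, 32, 128) are the layer defaults assumed for illustration, and the ActivationBalancer/DoubleSwish layers are omitted since they do not change spatial shape. It verifies T' = (T-7)//2 on the time axis and out_height = ((idim-1)//2 - 1)//2 on the frequency axis:

```python
import torch
import torch.nn as nn

# Stand-ins with the patched kernel/stride/padding values.
conv = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=(0, 1)),        # time: T -> T-2
    nn.Conv2d(8, 32, kernel_size=3, stride=2, padding=0),  # time: T-2 -> (T-3)//2
    nn.Conv2d(32, 128, kernel_size=3, stride=(1, 2)),      # time: -> (T-3)//2 - 2
)

idim = 80
for T in (9, 20, 101):
    y = conv(torch.randn(1, 1, T, idim))             # input is (N, C, T, idim)
    assert y.shape[2] == (T - 7) // 2                # == (T-3)//2 - 2
    assert y.shape[3] == ((idim - 1) // 2 - 1) // 2  # matches out_height
    print(T, "->", tuple(y.shape))
```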