diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 79a6ebefb..9facae5ce 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -478,9 +478,12 @@ class AttentionDownsample(torch.nn.Module):
         self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
 
         # fill in the extra dimensions with a projection of the input
-        self.extra_proj = nn.Linear(in_channels * downsample,
-                                    out_channels - in_channels,
-                                    bias=False)
+        if out_channels > in_channels:
+            self.extra_proj = nn.Linear(in_channels * downsample,
+                                        out_channels - in_channels,
+                                        bias=False)
+        else:
+            self.extra_proj = None
         self.downsample = downsample
 
     def forward(self,
@@ -509,10 +512,12 @@ class AttentionDownsample(torch.nn.Module):
 
         # ans1 is the first `in_channels` channels of the output
         ans1 = (src * weights).sum(dim=1)
         src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels)
-        ans2 = self.extra_proj(src)
-
-        ans = torch.cat((ans1, ans2), dim=2)
+        if self.extra_proj is not None:
+            ans2 = self.extra_proj(src)
+            ans = torch.cat((ans1, ans2), dim=2)
+        else:
+            ans = ans1
         return ans
 
 
@@ -551,7 +556,7 @@ class SimpleCombiner(torch.nn.Module):
                  dim1: int,
                  dim2: int):
         super(SimpleCombiner, self).__init__()
-        assert dim2 > dim1
+        assert dim2 >= dim1
         self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01)
         self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01)
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 4d87d3e73..7d9ec647e 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -113,7 +113,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--encoder-dims",
         type=str,
-        default="320,512,512",
+        default="384,384",
         help="Attention dimension in 2 blocks of conformer encoder layers, comma separated, "
         "and the output dim of the encoder",
     )
@@ -121,7 +121,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--conformer-subsampling-factor",
         type=int,
-        default=4,
+        default=1,
         help="Subsampling factor for 2nd stack of encoder layers.",
     )
 