Remove 4 layers.

This commit is contained in:
Daniel Povey 2022-11-27 13:24:30 +08:00
parent 2e0111e6ef
commit a6fb9772a8
2 changed files with 9 additions and 1 deletion

View File

@ -106,7 +106,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--num-encoder-layers", "--num-encoder-layers",
type=str, type=str,
default="2,5,5,4,4,5", default="2,4,4,3,3,4",
help="Number of zipformer encoder layers per stack, comma separated.", help="Number of zipformer encoder layers per stack, comma separated.",
) )

View File

@ -1684,6 +1684,9 @@ class Conv2dSubsampling(nn.Module):
) )
out_height = (((in_channels - 1) // 2) - 1) // 2 out_height = (((in_channels - 1) // 2) - 1) // 2
self.scale = nn.Parameter(torch.ones(out_height * layer3_channels))
self.scale_max = 1.0
self.scale_min = ScheduledFloat((0.0, 0.9), (4000.0, 0.01))
self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels, self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels,
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob()) aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob())
@ -1709,6 +1712,11 @@ class Conv2dSubsampling(nn.Module):
x = x.transpose(1, 2).reshape(b, t, c * f) x = x.transpose(1, 2).reshape(b, t, c * f)
# now x: (N, ((T-1)//2 - 1))//2, out_height * layer3_channels)) # now x: (N, ((T-1)//2 - 1))//2, out_height * layer3_channels))
x = x * limit_param_value(self.scale,
min=float(self.scale_min),
max=float(self.scale_max))
x = self.out(x) x = self.out(x)
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim) # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
x = self.dropout(x) x = self.dropout(x)