mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Remove 4 layers.
This commit is contained in:
parent
2e0111e6ef
commit
a6fb9772a8
@ -106,7 +106,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--num-encoder-layers",
|
||||
type=str,
|
||||
default="2,5,5,4,4,5",
|
||||
default="2,4,4,3,3,4",
|
||||
help="Number of zipformer encoder layers per stack, comma separated.",
|
||||
)
|
||||
|
||||
|
||||
@ -1684,6 +1684,9 @@ class Conv2dSubsampling(nn.Module):
|
||||
)
|
||||
out_height = (((in_channels - 1) // 2) - 1) // 2
|
||||
|
||||
self.scale = nn.Parameter(torch.ones(out_height * layer3_channels))
|
||||
self.scale_max = 1.0
|
||||
self.scale_min = ScheduledFloat((0.0, 0.9), (4000.0, 0.01))
|
||||
|
||||
self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels,
|
||||
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob())
|
||||
@ -1709,6 +1712,11 @@ class Conv2dSubsampling(nn.Module):
|
||||
|
||||
x = x.transpose(1, 2).reshape(b, t, c * f)
|
||||
# now x: (N, ((T-1)//2 - 1)//2, out_height * layer3_channels)
|
||||
|
||||
x = x * limit_param_value(self.scale,
|
||||
min=float(self.scale_min),
|
||||
max=float(self.scale_max))
|
||||
|
||||
x = self.out(x)
|
||||
# Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
|
||||
x = self.dropout(x)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user