mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Remove 4 layers.
This commit is contained in:
parent
2e0111e6ef
commit
a6fb9772a8
@ -106,7 +106,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-encoder-layers",
|
"--num-encoder-layers",
|
||||||
type=str,
|
type=str,
|
||||||
default="2,5,5,4,4,5",
|
default="2,4,4,3,3,4",
|
||||||
help="Number of zipformer encoder layers per stack, comma separated.",
|
help="Number of zipformer encoder layers per stack, comma separated.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -1684,6 +1684,9 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
)
|
)
|
||||||
out_height = (((in_channels - 1) // 2) - 1) // 2
|
out_height = (((in_channels - 1) // 2) - 1) // 2
|
||||||
|
|
||||||
|
self.scale = nn.Parameter(torch.ones(out_height * layer3_channels))
|
||||||
|
self.scale_max = 1.0
|
||||||
|
self.scale_min = ScheduledFloat((0.0, 0.9), (4000.0, 0.01))
|
||||||
|
|
||||||
self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels,
|
self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels,
|
||||||
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob())
|
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob())
|
||||||
@ -1709,6 +1712,11 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
|
|
||||||
x = x.transpose(1, 2).reshape(b, t, c * f)
|
x = x.transpose(1, 2).reshape(b, t, c * f)
|
||||||
# now x: (N, ((T-1)//2 - 1))//2, out_height * layer3_channels))
|
# now x: (N, ((T-1)//2 - 1))//2, out_height * layer3_channels))
|
||||||
|
|
||||||
|
x = x * limit_param_value(self.scale,
|
||||||
|
min=float(self.scale_min),
|
||||||
|
max=float(self.scale_max))
|
||||||
|
|
||||||
x = self.out(x)
|
x = self.out(x)
|
||||||
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
|
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
|
||||||
x = self.dropout(x)
|
x = self.dropout(x)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user