mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Add deriv-balancer at output of embedding.
This commit is contained in:
parent
2e6d170be8
commit
1962fe298b
@ -57,6 +57,8 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
)
|
)
|
||||||
self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim)
|
self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim)
|
||||||
self.out_norm = BasicNorm(odim)
|
self.out_norm = BasicNorm(odim)
|
||||||
|
# constrain mean of output to be close to zero.
|
||||||
|
self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.4, max_positive=0.6)
|
||||||
self._reset_parameters()
|
self._reset_parameters()
|
||||||
|
|
||||||
def _reset_parameters(self):
|
def _reset_parameters(self):
|
||||||
@ -84,6 +86,7 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||||
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
|
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
|
||||||
x = self.out_norm(x)
|
x = self.out_norm(x)
|
||||||
|
x = self.out_balancer(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@ -110,7 +110,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup",
|
default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup_embed",
|
||||||
help="""The experiment dir.
|
help="""The experiment dir.
|
||||||
It specifies the directory where all training related
|
It specifies the directory where all training related
|
||||||
files, e.g., checkpoints, log, etc, are saved
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
Loading…
x
Reference in New Issue
Block a user