mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Double the threshold in brelu; slightly increase max_factor.
This commit is contained in:
parent
74f2b163de
commit
65b09dd5f2
@ -47,15 +47,15 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
nn.Conv2d(
|
nn.Conv2d(
|
||||||
in_channels=1, out_channels=odim, kernel_size=3, stride=2
|
in_channels=1, out_channels=odim, kernel_size=3, stride=2
|
||||||
),
|
),
|
||||||
DerivBalancer(channel_dim=1, threshold=0.02,
|
DerivBalancer(channel_dim=1, threshold=0.05,
|
||||||
max_factor=0.02),
|
max_factor=0.025),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
ExpScale(odim, 1, 1, speed=20.0),
|
ExpScale(odim, 1, 1, speed=20.0),
|
||||||
nn.Conv2d(
|
nn.Conv2d(
|
||||||
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
|
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
|
||||||
),
|
),
|
||||||
DerivBalancer(channel_dim=1, threshold=0.02,
|
DerivBalancer(channel_dim=1, threshold=0.05,
|
||||||
max_factor=0.02),
|
max_factor=0.025),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
ExpScale(odim, 1, 1, speed=20.0),
|
ExpScale(odim, 1, 1, speed=20.0),
|
||||||
)
|
)
|
||||||
|
@ -156,8 +156,8 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
|
|
||||||
self.feed_forward = nn.Sequential(
|
self.feed_forward = nn.Sequential(
|
||||||
nn.Linear(d_model, dim_feedforward),
|
nn.Linear(d_model, dim_feedforward),
|
||||||
DerivBalancer(channel_dim=-1, threshold=0.02,
|
DerivBalancer(channel_dim=-1, threshold=0.05,
|
||||||
max_factor=0.02),
|
max_factor=0.025),
|
||||||
ExpScaleSwish(dim_feedforward, speed=20.0),
|
ExpScaleSwish(dim_feedforward, speed=20.0),
|
||||||
nn.Dropout(dropout),
|
nn.Dropout(dropout),
|
||||||
nn.Linear(dim_feedforward, d_model),
|
nn.Linear(dim_feedforward, d_model),
|
||||||
@ -165,8 +165,8 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
|
|
||||||
self.feed_forward_macaron = nn.Sequential(
|
self.feed_forward_macaron = nn.Sequential(
|
||||||
nn.Linear(d_model, dim_feedforward),
|
nn.Linear(d_model, dim_feedforward),
|
||||||
DerivBalancer(channel_dim=-1, threshold=0.02,
|
DerivBalancer(channel_dim=-1, threshold=0.05,
|
||||||
max_factor=0.02),
|
max_factor=0.025),
|
||||||
ExpScaleSwish(dim_feedforward, speed=20.0),
|
ExpScaleSwish(dim_feedforward, speed=20.0),
|
||||||
nn.Dropout(dropout),
|
nn.Dropout(dropout),
|
||||||
nn.Linear(dim_feedforward, d_model),
|
nn.Linear(dim_feedforward, d_model),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user