Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00).
Reduce initial_scale.
This commit is contained in:
parent b7b2d8970b
commit db7a3b6eea
@@ -421,7 +421,7 @@ class RelPositionMultiheadAttention(nn.Module):
 ), "embed_dim must be divisible by num_heads"
 
 self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
-self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True)
+self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25)
 
 # linear transformation for positional encoding.
 self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
@@ -869,7 +869,7 @@ class ConvolutionModule(nn.Module):
 stride=1,
 padding=0,
 bias=bias,
-initial_scale=0.5
+initial_scale=0.25
 )
 
 def forward(self, x: Tensor) -> Tensor:
Loading…
x
Reference in New Issue
Block a user