mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
Move Decorrelate module to after encoder, with scale 0.02->0.1
This commit is contained in:
parent
4a5143e548
commit
e891a65735
@ -94,6 +94,9 @@ class Conformer(EncoderInterface):
|
||||
aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
|
||||
)
|
||||
|
||||
self.decorrelate = Decorrelate(d_model, scale=0.1)
|
||||
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
@ -129,6 +132,8 @@ class Conformer(EncoderInterface):
|
||||
x, pos_emb, src_key_padding_mask=mask, warmup=warmup
|
||||
) # (T, N, C)
|
||||
|
||||
x = self.decorrelate(x)
|
||||
|
||||
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
return x, lengths
|
||||
@ -198,7 +203,6 @@ class ConformerEncoderLayer(nn.Module):
|
||||
)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.decorrelate = Decorrelate(d_model, scale=0.02)
|
||||
|
||||
|
||||
def forward(
|
||||
@ -263,8 +267,6 @@ class ConformerEncoderLayer(nn.Module):
|
||||
|
||||
src = self.norm_final(self.balancer(src))
|
||||
|
||||
src = self.decorrelate(src)
|
||||
|
||||
if alpha != 1.0:
|
||||
src = alpha * src + (1 - alpha) * src_orig
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user