mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Move Decorrelate module to after encoder, with scale 0.02->0.1
This commit is contained in:
parent
4a5143e548
commit
e891a65735
@ -94,6 +94,9 @@ class Conformer(EncoderInterface):
|
|||||||
aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
|
aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.decorrelate = Decorrelate(d_model, scale=0.1)
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
|
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
@ -129,6 +132,8 @@ class Conformer(EncoderInterface):
|
|||||||
x, pos_emb, src_key_padding_mask=mask, warmup=warmup
|
x, pos_emb, src_key_padding_mask=mask, warmup=warmup
|
||||||
) # (T, N, C)
|
) # (T, N, C)
|
||||||
|
|
||||||
|
x = self.decorrelate(x)
|
||||||
|
|
||||||
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||||
|
|
||||||
return x, lengths
|
return x, lengths
|
||||||
@ -198,7 +203,6 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
self.dropout = nn.Dropout(dropout)
|
||||||
self.decorrelate = Decorrelate(d_model, scale=0.02)
|
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -263,8 +267,6 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
|
|
||||||
src = self.norm_final(self.balancer(src))
|
src = self.norm_final(self.balancer(src))
|
||||||
|
|
||||||
src = self.decorrelate(src)
|
|
||||||
|
|
||||||
if alpha != 1.0:
|
if alpha != 1.0:
|
||||||
src = alpha * src + (1 - alpha) * src_orig
|
src = alpha * src + (1 - alpha) * src_orig
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user