mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Remove Decorrelate()
This commit is contained in:
parent
c1f487e36d
commit
7338c60296
@ -29,7 +29,6 @@ from scaling import (
|
||||
ScaledConv1d,
|
||||
ScaledConv2d,
|
||||
ScaledLinear,
|
||||
Decorrelate,
|
||||
)
|
||||
from torch import Tensor, nn
|
||||
|
||||
@ -94,8 +93,6 @@ class Conformer(EncoderInterface):
|
||||
aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
|
||||
)
|
||||
|
||||
self.decorrelate = Decorrelate(d_model, scale=0.05)
|
||||
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
|
||||
@ -132,8 +129,6 @@ class Conformer(EncoderInterface):
|
||||
x, pos_emb, src_key_padding_mask=mask, warmup=warmup
|
||||
) # (T, N, C)
|
||||
|
||||
x = self.decorrelate(x)
|
||||
|
||||
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
return x, lengths
|
||||
|
Loading…
x
Reference in New Issue
Block a user