Remove Decorrelate()

This commit is contained in:
Daniel Povey 2022-06-13 16:07:15 +08:00
parent c1f487e36d
commit 7338c60296

View File

@ -29,7 +29,6 @@ from scaling import (
ScaledConv1d,
ScaledConv2d,
ScaledLinear,
Decorrelate,
)
from torch import Tensor, nn
@ -94,8 +93,6 @@ class Conformer(EncoderInterface):
aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
)
self.decorrelate = Decorrelate(d_model, scale=0.05)
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
@ -132,8 +129,6 @@ class Conformer(EncoderInterface):
x, pos_emb, src_key_padding_mask=mask, warmup=warmup
) # (T, N, C)
x = self.decorrelate(x)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return x, lengths