Add decorrelation to joiner

This commit is contained in:
Daniel Povey 2022-06-07 16:47:54 +08:00
parent cd6b707e2b
commit 7c6d923d3f

View File

@ -17,7 +17,11 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from scaling import ScaledConv1d, ScaledEmbedding
from scaling import (
ScaledConv1d,
ScaledEmbedding,
Decorrelate,
)
class Decoder(nn.Module):
@ -59,6 +63,9 @@ class Decoder(nn.Module):
embedding_dim=decoder_dim,
padding_idx=blank_id,
)
self.decorrelate = Decorrelate(apply_prob=0.25,
dropout_rate=0.05)
self.blank_id = blank_id
assert context_size >= 1, context_size
@ -99,5 +106,6 @@ class Decoder(nn.Module):
assert embedding_out.size(-1) == self.context_size
embedding_out = self.conv(embedding_out)
embedding_out = embedding_out.permute(0, 2, 1)
embedding_out = self.decorrelate(embedding_out)
embedding_out = F.relu(embedding_out)
return embedding_out