diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index b6d94aaf1..eba10f8c7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -17,7 +17,11 @@ import torch import torch.nn as nn import torch.nn.functional as F -from scaling import ScaledConv1d, ScaledEmbedding +from scaling import ( + ScaledConv1d, + ScaledEmbedding, + Decorrelate, +) class Decoder(nn.Module): @@ -59,6 +63,9 @@ class Decoder(nn.Module): embedding_dim=decoder_dim, padding_idx=blank_id, ) + self.decorrelate = Decorrelate(apply_prob=0.25, + dropout_rate=0.05) + self.blank_id = blank_id assert context_size >= 1, context_size @@ -99,5 +106,6 @@ class Decoder(nn.Module): assert embedding_out.size(-1) == self.context_size embedding_out = self.conv(embedding_out) embedding_out = embedding_out.permute(0, 2, 1) + embedding_out = self.decorrelate(embedding_out) embedding_out = F.relu(embedding_out) return embedding_out