From 7c6d923d3f7550b8ff9469cec952192f7b0ba5ac Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 7 Jun 2022 16:47:54 +0800
Subject: [PATCH] Add decorrelation to joiner

---
 .../ASR/pruned_transducer_stateless2/decoder.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py
index b6d94aaf1..eba10f8c7 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py
@@ -17,7 +17,11 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from scaling import ScaledConv1d, ScaledEmbedding
+from scaling import (
+    ScaledConv1d,
+    ScaledEmbedding,
+    Decorrelate,
+)
 
 
 class Decoder(nn.Module):
@@ -59,6 +63,9 @@ class Decoder(nn.Module):
             embedding_dim=decoder_dim,
             padding_idx=blank_id,
         )
+        self.decorrelate = Decorrelate(apply_prob=0.25,
+                                       dropout_rate=0.05)
+
         self.blank_id = blank_id
 
         assert context_size >= 1, context_size
@@ -99,5 +106,6 @@ class Decoder(nn.Module):
         assert embedding_out.size(-1) == self.context_size
         embedding_out = self.conv(embedding_out)
         embedding_out = embedding_out.permute(0, 2, 1)
+        embedding_out = self.decorrelate(embedding_out)
         embedding_out = F.relu(embedding_out)
         return embedding_out