From 7eb5a84cbeb4242736b28d1d1ea5a118cb1cc256 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Mar 2022 21:00:43 +0800 Subject: [PATCH] Add identity pre_norm_final for diagnostics. --- egs/librispeech/ASR/transducer_stateless/conformer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index fa25e6ca0..389a7cb7f 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -181,6 +181,7 @@ class ConformerEncoderLayer(nn.Module): self.scale_ff = ExpScale(1, speed=10.0, initial_scale=0.5) self.scale_ff_macaron = ExpScale(1, speed=10.0, initial_scale=0.5) + self.pre_norm_final = Identity() self.norm_final = BasicNorm(d_model) self.dropout = nn.Dropout(dropout) @@ -244,7 +245,7 @@ class ConformerEncoderLayer(nn.Module): # feed forward module src = src + self.dropout(self.feed_forward(self.scale_ff(src))) - src = self.norm_final(src) + src = self.norm_final(self.pre_norm_final(src)) return src @@ -930,8 +931,9 @@ class SwishOffset(torch.nn.Module): return x * torch.sigmoid(x + self.offset) -def identity(x): - return x +class Identity(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + return x class RandomCombine(torch.nn.Module):