Add identity pre_norm_final for diagnostics.

This commit is contained in:
Daniel Povey 2022-03-11 21:00:43 +08:00
parent 2d3a76292d
commit 7eb5a84cbe

View File

@ -181,6 +181,7 @@ class ConformerEncoderLayer(nn.Module):
self.scale_ff = ExpScale(1, speed=10.0, initial_scale=0.5) self.scale_ff = ExpScale(1, speed=10.0, initial_scale=0.5)
self.scale_ff_macaron = ExpScale(1, speed=10.0, initial_scale=0.5) self.scale_ff_macaron = ExpScale(1, speed=10.0, initial_scale=0.5)
self.pre_norm_final = Identity()
self.norm_final = BasicNorm(d_model) self.norm_final = BasicNorm(d_model)
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
@ -244,7 +245,7 @@ class ConformerEncoderLayer(nn.Module):
# feed forward module # feed forward module
src = src + self.dropout(self.feed_forward(self.scale_ff(src))) src = src + self.dropout(self.feed_forward(self.scale_ff(src)))
src = self.norm_final(src) src = self.norm_final(self.pre_norm_final(src))
return src return src
@ -930,8 +931,9 @@ class SwishOffset(torch.nn.Module):
return x * torch.sigmoid(x + self.offset) return x * torch.sigmoid(x + self.offset)
def identity(x): class Identity(torch.nn.Module):
return x def forward(self, x: Tensor) -> Tensor:
return x
class RandomCombine(torch.nn.Module): class RandomCombine(torch.nn.Module):