mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 18:42:19 +00:00
Add identity pre_norm_final for diagnostics.
This commit is contained in:
parent
2d3a76292d
commit
7eb5a84cbe
@ -181,6 +181,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
self.scale_ff = ExpScale(1, speed=10.0, initial_scale=0.5)
|
||||
self.scale_ff_macaron = ExpScale(1, speed=10.0, initial_scale=0.5)
|
||||
|
||||
self.pre_norm_final = Identity()
|
||||
self.norm_final = BasicNorm(d_model)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
@ -244,7 +245,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
# feed forward module
|
||||
src = src + self.dropout(self.feed_forward(self.scale_ff(src)))
|
||||
|
||||
src = self.norm_final(src)
|
||||
src = self.norm_final(self.pre_norm_final(src))
|
||||
|
||||
return src
|
||||
|
||||
@ -930,7 +931,8 @@ class SwishOffset(torch.nn.Module):
|
||||
return x * torch.sigmoid(x + self.offset)
|
||||
|
||||
|
||||
def identity(x):
|
||||
class Identity(torch.nn.Module):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return x
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user