mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 10:32:17 +00:00
Add identity pre_norm_final for diagnostics.
This commit is contained in:
parent
2d3a76292d
commit
7eb5a84cbe
@ -181,6 +181,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
self.scale_ff = ExpScale(1, speed=10.0, initial_scale=0.5)
|
self.scale_ff = ExpScale(1, speed=10.0, initial_scale=0.5)
|
||||||
self.scale_ff_macaron = ExpScale(1, speed=10.0, initial_scale=0.5)
|
self.scale_ff_macaron = ExpScale(1, speed=10.0, initial_scale=0.5)
|
||||||
|
|
||||||
|
self.pre_norm_final = Identity()
|
||||||
self.norm_final = BasicNorm(d_model)
|
self.norm_final = BasicNorm(d_model)
|
||||||
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
self.dropout = nn.Dropout(dropout)
|
||||||
@ -244,7 +245,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
# feed forward module
|
# feed forward module
|
||||||
src = src + self.dropout(self.feed_forward(self.scale_ff(src)))
|
src = src + self.dropout(self.feed_forward(self.scale_ff(src)))
|
||||||
|
|
||||||
src = self.norm_final(src)
|
src = self.norm_final(self.pre_norm_final(src))
|
||||||
|
|
||||||
return src
|
return src
|
||||||
|
|
||||||
@ -930,8 +931,9 @@ class SwishOffset(torch.nn.Module):
|
|||||||
return x * torch.sigmoid(x + self.offset)
|
return x * torch.sigmoid(x + self.offset)
|
||||||
|
|
||||||
|
|
||||||
def identity(x):
|
class Identity(torch.nn.Module):
|
||||||
return x
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class RandomCombine(torch.nn.Module):
|
class RandomCombine(torch.nn.Module):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user