mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
Update conformer.py
This commit is contained in:
parent
4392da7235
commit
9142bbb17d
@ -193,10 +193,14 @@ class ConformerEncoderLayer(nn.Module):
|
||||
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
self.linear1 = nn.Linear(512, 1024)
|
||||
self.conv1d_abs = Conv1dAbs(1024, 64, kernel_size=21, padding=10)
|
||||
self.in_channel = 1024
|
||||
self.out_channel = 64
|
||||
self.kernel_size = 21
|
||||
self.padding = 10
|
||||
self.linear1 = nn.Linear(512, self.in_channel)
|
||||
self.conv1d_abs = Conv1dAbs(self.in_channel, self.out_channel, kernel_size=self.kernel_size, padding=self.padding)
|
||||
self.activation = nn.ReLU()
|
||||
self.linear2 = nn.Linear(64, 512)
|
||||
self.linear2 = nn.Linear(self.out_channel, 512)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -239,7 +243,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
src = self.linear1(src)
|
||||
src = torch.exp(src.clamp(max=75))
|
||||
src = src.permute(1, 2, 0)
|
||||
src = self.conv1d_abs(src)
|
||||
src = self.conv1d_abs(src) / self.kernel_size
|
||||
src = self.activation(src).permute(2, 0, 1)
|
||||
src = torch.log(src)
|
||||
src = self.linear2(src)
|
||||
|
Loading…
x
Reference in New Issue
Block a user