Mirror of https://github.com/k2-fsa/icefall.git
Will try to add a weight 0.01 before the exp operation
commit ae667f4801 (parent 99274cbb8f)
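The change itself is small: the output of self.linear1 is now multiplied by 0.01 before it goes into torch.exp, the output of the final LayerNorm is scaled by 0.25, and the runtime check_inf guard is dropped. Presumably the pre-exp weight keeps the argument of exp close to zero instead of near the ±75 clamp boundary. A quick illustration (with made-up magnitudes, not values taken from training):

import torch

x = torch.randn(4, 768) * 30.0  # hypothetical large linear1 output
print(torch.exp(x.clamp(min=-75, max=75)).max())           # up to exp(75) ~ 3.7e32
print(torch.exp((0.01 * x).clamp(min=-75, max=75)).max())  # roughly exp(1), well-behaved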
@@ -199,15 +199,16 @@ class ConformerEncoderLayer(nn.Module):
         self.kernel_size = 31
         self.padding = int((self.kernel_size - 1) / 2)
-        self.conv1d_channels = 768
-        self.linear1 = nn.Linear(512, self.conv1d_channels)
+        self.in_conv1d_channels = 768
+        self.out_conv1d_channels = 768
+        self.linear1 = nn.Linear(512, self.in_conv1d_channels)
         self.conv1d_abs = Conv1dAbs(
-            self.conv1d_channels,
-            self.conv1d_channels,
+            self.in_conv1d_channels,
+            self.out_conv1d_channels,
             kernel_size=self.kernel_size,
             padding=self.padding,
         )
-        self.linear2 = nn.Linear(self.conv1d_channels, 512)
+        self.linear2 = nn.Linear(self.out_conv1d_channels, 512)

     def forward(
         self,
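Conv1dAbs is defined elsewhere in this branch and is not part of the diff. The class below is only a guessed stand-in, not the icefall definition: a Conv1d that convolves with the absolute values of its weights and bias, so that a strictly positive input yields a non-negative output, which is what lets the forward pass in the next hunk take torch.log of the convolution result.

import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv1dAbs(nn.Conv1d):
    """Hypothetical stand-in: a Conv1d that uses |weight| and |bias|."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.conv1d(
            x,
            self.weight.abs(),
            None if self.bias is None else self.bias.abs(),
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )


conv = Conv1dAbs(768, 768, kernel_size=31, padding=15)
x = torch.rand(2, 768, 100) + 1e-3            # strictly positive input, shape (B, C, T)
y = conv(x)
print(y.shape, bool((y >= 0).all()))          # torch.Size([2, 768, 100]) True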
@@ -260,27 +261,20 @@ class ConformerEncoderLayer(nn.Module):
             src = self.norm_mha(src)

         # conv1dabs modified attention module
-        inf = torch.tensor(float("inf"), device=src.device)
-        def check_inf(x):
-            if x.max() == inf:
-                print("Error: inf found: ", x)
-                assert 0
         residual = src
         if self.normalize_before:
             src = self.norm_conv_abs(src)
-        src = self.linear1(src * 0.25)
+        # src = self.linear1(src * 0.25)
+        src = 0.01 * self.linear1(src * 0.25)
         src = torch.exp(src.clamp(min=-75, max=75))
-        check_inf(src)
         src = src.permute(1, 2, 0)
         src = self.conv1d_abs(src) / self.kernel_size
-        check_inf(src)
         src = src.permute(2, 0, 1)
         src = torch.log(src.clamp(min=1e-20))
-        check_inf(src)
         src = self.linear2(src)
-        src = self.layernorm(src)
-        # multipy the output by 0.5 later.
-        # do a comparison.
+        src = 0.25 * self.layernorm(src)
         src = residual + self.dropout(src)
         if not self.normalize_before:
             src = self.norm_conv_abs(src)
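Putting the pieces together, here is a self-contained sketch of the branch as it looks after this commit. The sizes follow the constructor hunk above; conv1d_abs is approximated by an ordinary nn.Conv1d whose parameters are forced non-negative, which is only an assumption about what Conv1dAbs does, and norm_conv_abs, layernorm, and dropout are assumed to be plain LayerNorm/Dropout modules.

import torch
import torch.nn as nn

d_model, channels, kernel_size = 512, 768, 31
linear1 = nn.Linear(d_model, channels)
conv1d_abs = nn.Conv1d(channels, channels, kernel_size, padding=(kernel_size - 1) // 2)
with torch.no_grad():  # crude stand-in for Conv1dAbs: non-negative weights and bias
    conv1d_abs.weight.abs_()
    conv1d_abs.bias.abs_()
linear2 = nn.Linear(channels, d_model)
norm_conv_abs = nn.LayerNorm(d_model)
layernorm = nn.LayerNorm(d_model)
dropout = nn.Dropout(0.1)


def conv1d_abs_branch(src: torch.Tensor) -> torch.Tensor:
    """src: (seq_len, batch, d_model), as in the conformer encoder layer."""
    residual = src
    src = norm_conv_abs(src)                     # normalize_before path
    src = 0.01 * linear1(src * 0.25)             # new: 0.01 weight before exp
    src = torch.exp(src.clamp(min=-75, max=75))  # strictly positive activations
    src = src.permute(1, 2, 0)                   # (T, B, C) -> (B, C, T) for the conv
    src = conv1d_abs(src) / kernel_size          # non-negative weights keep positivity
    src = src.permute(2, 0, 1)                   # back to (T, B, C)
    src = torch.log(src.clamp(min=1e-20))        # safe: inputs are > 0 after the clamp
    src = linear2(src)
    src = 0.25 * layernorm(src)                  # new: scaled LayerNorm output
    return residual + dropout(src)


out = conv1d_abs_branch(torch.randn(100, 2, d_model))
print(out.shape)  # torch.Size([100, 2, 512])

The exp, positive-weight convolution, and log sandwich behaves like a smooth log-sum-exp style pooling over the 31-frame window; the new 0.01 and 0.25 factors rescale its input and output.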
@@ -871,6 +865,7 @@ class RelPositionMultiheadAttention(nn.Module):
         else:
             return attn_output, None

+
 class ConvolutionModule(nn.Module):
     """ConvolutionModule in Conformer model.
     Modified from