Will try to add a weight 0.01 before the exp operation

This commit is contained in:
Mingshuang Luo 2022-01-14 14:31:56 +08:00
parent 99274cbb8f
commit ae667f4801

View File

@ -199,15 +199,16 @@ class ConformerEncoderLayer(nn.Module):
self.kernel_size = 31
self.padding = int((self.kernel_size - 1) / 2)
self.conv1d_channels = 768
self.linear1 = nn.Linear(512, self.conv1d_channels)
self.in_conv1d_channels = 768
self.out_conv1d_channels = 768
self.linear1 = nn.Linear(512, self.in_conv1d_channels)
self.conv1d_abs = Conv1dAbs(
self.conv1d_channels,
self.conv1d_channels,
self.in_conv1d_channels,
self.out_conv1d_channels,
kernel_size=self.kernel_size,
padding=self.padding,
)
self.linear2 = nn.Linear(self.conv1d_channels, 512)
self.linear2 = nn.Linear(self.out_conv1d_channels, 512)
def forward(
self,
@ -260,27 +261,20 @@ class ConformerEncoderLayer(nn.Module):
src = self.norm_mha(src)
# conv1dabs modified attention module
inf = torch.tensor(float("inf"), device=src.device)
def check_inf(x):
if x.max() == inf:
print("Error: inf found: ", x)
assert 0
residual = src
if self.normalize_before:
src = self.norm_conv_abs(src)
src = self.linear1(src * 0.25)
# src = self.linear1(src * 0.25)
src = 0.01 * self.linear1(src * 0.25)
src = torch.exp(src.clamp(min=-75, max=75))
check_inf(src)
src = src.permute(1, 2, 0)
src = self.conv1d_abs(src) / self.kernel_size
check_inf(src)
src = src.permute(2, 0, 1)
src = torch.log(src.clamp(min=1e-20))
check_inf(src)
src = self.linear2(src)
src = self.layernorm(src)
# multipy the output by 0.5 later.
# do a comparison.
src = 0.25 * self.layernorm(src)
src = residual + self.dropout(src)
if not self.normalize_before:
src = self.norm_conv_abs(src)
@ -871,6 +865,7 @@ class RelPositionMultiheadAttention(nn.Module):
else:
return attn_output, None
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
Modified from