Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-26 18:24:18 +00:00
Will try to add a weight 0.01 before the exp operation

parent 99274cbb8f
commit ae667f4801
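
The 0.01 weight presumably serves numerical stability: the output of this branch is fed straight into torch.exp(), and scaling the pre-activation down keeps the exponentials near 1.0 instead of letting them blow up. A minimal illustration of the effect (not part of the commit):

import torch

x = torch.tensor([10.0, -10.0])
print(torch.exp(x))         # tensor([2.2026e+04, 4.5400e-05]) -- huge dynamic range
print(torch.exp(0.01 * x))  # tensor([1.1052, 0.9048]) -- stays close to 1.0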
@@ -199,15 +199,16 @@ class ConformerEncoderLayer(nn.Module):
 
         self.kernel_size = 31
         self.padding = int((self.kernel_size - 1) / 2)
-        self.conv1d_channels = 768
-        self.linear1 = nn.Linear(512, self.conv1d_channels)
+        self.in_conv1d_channels = 768
+        self.out_conv1d_channels = 768
+        self.linear1 = nn.Linear(512, self.in_conv1d_channels)
         self.conv1d_abs = Conv1dAbs(
-            self.conv1d_channels,
-            self.conv1d_channels,
+            self.in_conv1d_channels,
+            self.out_conv1d_channels,
             kernel_size=self.kernel_size,
             padding=self.padding,
         )
-        self.linear2 = nn.Linear(self.conv1d_channels, 512)
+        self.linear2 = nn.Linear(self.out_conv1d_channels, 512)
 
     def forward(
         self,
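
Conv1dAbs itself is not defined in this diff. Since its output is passed through torch.log() in the forward pass below, it presumably maps positive inputs to positive outputs. One plausible sketch of such a layer (an assumption, not the actual icefall implementation) is a Conv1d whose weights and bias are forced non-negative:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv1dAbs(nn.Conv1d):
    """Hypothetical sketch: behaves like nn.Conv1d, but applies abs()
    to the weights (and bias), so that positive inputs yield positive
    outputs and a following torch.log() stays well defined."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.conv1d(
            x,
            self.weight.abs(),
            self.bias.abs() if self.bias is not None else None,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )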
@@ -260,27 +261,20 @@ class ConformerEncoderLayer(nn.Module):
         src = self.norm_mha(src)
-
         # conv1dabs modified attention module
-        inf = torch.tensor(float("inf"), device=src.device)
-        def check_inf(x):
-            if x.max() == inf:
-                print("Error: inf found: ", x)
-                assert 0
         residual = src
         if self.normalize_before:
             src = self.norm_conv_abs(src)
-        src = self.linear1(src * 0.25)
-
+        # src = self.linear1(src * 0.25)
+        src = 0.01 * self.linear1(src * 0.25)
         src = torch.exp(src.clamp(min=-75, max=75))
-        check_inf(src)
         src = src.permute(1, 2, 0)
         src = self.conv1d_abs(src) / self.kernel_size
-        check_inf(src)
         src = src.permute(2, 0, 1)
         src = torch.log(src.clamp(min=1e-20))
-        check_inf(src)
         src = self.linear2(src)
-        src = self.layernorm(src)
+        # multiply the output by 0.5 later.
+        # do a comparison.
+        src = 0.25 * self.layernorm(src)
 
         src = residual + self.dropout(src)
         if not self.normalize_before:
             src = self.norm_conv_abs(src)
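
Taken together, this branch computes linear1 -> exp -> Conv1dAbs -> log -> linear2, i.e. a smoothing over a 31-frame window carried out in the exponential domain. Below is a self-contained sketch of the data flow and tensor shapes, using a plain nn.Conv1d plus abs() as a stand-in for the undefined Conv1dAbs:

import torch
import torch.nn as nn

torch.manual_seed(0)
T, N, C, H, K = 50, 2, 512, 768, 31        # time, batch, model dim, conv channels, kernel
linear1 = nn.Linear(C, H)
conv1d = nn.Conv1d(H, H, kernel_size=K, padding=(K - 1) // 2)  # stand-in for Conv1dAbs
linear2 = nn.Linear(H, C)

src = torch.randn(T, N, C)                 # (time, batch, channel), as in the diff
x = 0.01 * linear1(src * 0.25)             # the small weight keeps the exponent tame
x = torch.exp(x.clamp(min=-75, max=75))    # clamp guards exp() against overflow
x = x.permute(1, 2, 0)                     # (T, N, H) -> (N, H, T) for Conv1d
x = conv1d(x).abs() / K                    # abs() only mimics Conv1dAbs positivity
x = x.permute(2, 0, 1)                     # back to (T, N, H)
x = torch.log(x.clamp(min=1e-20))          # clamp guards log() against zeros
out = linear2(x)
print(out.shape)                           # torch.Size([50, 2, 512])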
@@ -871,6 +865,7 @@ class RelPositionMultiheadAttention(nn.Module):
         else:
             return attn_output, None
 
+
 class ConvolutionModule(nn.Module):
     """ConvolutionModule in Conformer model.
     Modified from