Mirror of https://github.com/k2-fsa/icefall.git
Do some changes for modified attention.
parent 9142bbb17d
commit 309461c185
@@ -181,6 +181,7 @@ class ConformerEncoderLayer(nn.Module):
 
         # define layernorm for conv1d_abs
         self.norm_conv_abs = nn.LayerNorm(d_model)
+        self.layernorm = nn.LayerNorm(d_model)
 
         self.ff_scale = 0.5
 
@@ -193,14 +194,17 @@ class ConformerEncoderLayer(nn.Module):
 
         self.normalize_before = normalize_before
 
-        self.in_channel = 1024
-        self.out_channel = 64
         self.kernel_size = 21
-        self.padding = 10
-        self.linear1 = nn.Linear(512, self.in_channel)
-        self.conv1d_abs = Conv1dAbs(self.in_channel, self.out_channel, kernel_size=self.kernel_size, padding=self.padding)
-        self.activation = nn.ReLU()
-        self.linear2 = nn.Linear(self.out_channel, 512)
+        self.padding = int((self.kernel_size - 1) / 2)
+        self.conv1d_channels = 768
+        self.linear1 = nn.Linear(512, self.conv1d_channels)
+        self.conv1d_abs = Conv1dAbs(
+            self.conv1d_channels,
+            self.conv1d_channels,
+            kernel_size=self.kernel_size,
+            padding=self.padding,
+        )
+        self.linear2 = nn.Linear(self.conv1d_channels, 512)
 
     def forward(
         self,
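Conv1dAbs itself is not defined in this diff. The sketch below is a hypothetical stand-in, assuming it is an nn.Conv1d whose weight (and bias) are passed through abs() at call time, so that convolving the strictly positive, exp()-activated input keeps the output positive and the later log() well defined; the real class elsewhere in the repository may differ.

import torch
import torch.nn as nn


class Conv1dAbs(nn.Conv1d):
    # Hypothetical sketch only: force non-negative parameters at call time
    # so a positive input can never produce a non-positive output.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.conv1d(
            x,
            self.weight.abs(),
            self.bias.abs() if self.bias is not None else None,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )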
@@ -236,17 +240,33 @@ class ConformerEncoderLayer(nn.Module):
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
+        # modified-attention module
+        inf = torch.tensor(float("inf"), device=src.device)
+
+        def check_inf(x):
+            if x.max() == inf:
+                print("Error: inf found: ", x)
+                assert 0
+
         # modified-attention module
 
         residual = src
         if self.normalize_before:
             src = self.norm_conv_abs(src)
-        src = self.linear1(src)
-        src = torch.exp(src.clamp(max=75))
+
+        src = self.linear1(src * 0.25)
+        src = torch.exp(src.clamp(min=-75, max=75))
+        check_inf(src)
         src = src.permute(1, 2, 0)
         src = self.conv1d_abs(src) / self.kernel_size
-        src = self.activation(src).permute(2, 0, 1)
-        src = torch.log(src)
+        check_inf(src)
+        src = src.permute(2, 0, 1)
+        src = torch.log(src.clamp(min=1e-20))
+        check_inf(src)
         src = self.linear2(src)
+        src = self.layernorm(src)
+        # multipy the output by 0.5 later.
+        # do a comparison.
 
         src = residual + self.dropout(src)
         if not self.normalize_before:
             src = self.norm_conv_abs(src)
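For orientation, here is a shape-level sketch of the data flow in the modified-attention branch above. It assumes d_model = 512 (implied by the Linear sizes in this diff) and uses a plain Conv1d with abs()-ed parameters as a stand-in for Conv1dAbs; it only illustrates the tensor shapes and the exp / conv / log pattern, not the trained module.

import torch

T, N, d_model = 100, 8, 512                     # (time, batch, feature); d_model assumed
conv1d_channels, kernel_size = 768, 21
padding = (kernel_size - 1) // 2

src = torch.randn(T, N, d_model)
h = torch.nn.Linear(d_model, conv1d_channels)(src * 0.25)
h = torch.exp(h.clamp(min=-75, max=75))         # strictly positive activations
h = h.permute(1, 2, 0)                          # (N, C, T) layout for Conv1d
conv = torch.nn.Conv1d(conv1d_channels, conv1d_channels, kernel_size, padding=padding)
h = torch.nn.functional.conv1d(
    h, conv.weight.abs(), conv.bias.abs(), padding=padding
)                                               # stand-in for Conv1dAbs
h = h / kernel_size
h = h.permute(2, 0, 1)                          # back to (T, N, C)
h = torch.log(h.clamp(min=1e-20))               # safe log of a positive signal
out = torch.nn.Linear(conv1d_channels, d_model)(h)
out = torch.nn.LayerNorm(d_model)(out)          # same (T, N, d_model) shape as src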
@@ -387,8 +407,8 @@ class RelPositionalEncoding(torch.nn.Module):
         pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
 
         # Reserve the order of positive indices and concat both positive and
-        # negative indices. This is used to support the shifting trick
-        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+        # negative indices. This is used to support the shifting trick as in "T
+        # ransformer-XL:Attentive Language Models Beyond a Fixed-Length Context"
         pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
         pe_negative = pe_negative[1:].unsqueeze(0)
         pe = torch.cat([pe_positive, pe_negative], dim=1)
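The comment re-wrap above does not change behaviour. For reference, a self-contained sketch of how pe_positive and pe_negative are built and concatenated, mirroring the context lines of this hunk; the sin/cos fill follows the standard relative positional encoding construction used in this file, with small assumed sizes T = 5 and d_model = 8.

import math
import torch

T, d_model = 5, 8                                # assumed toy sizes
position = torch.arange(0, T, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
    torch.arange(0, d_model, 2, dtype=torch.float32)
    * -(math.log(10000.0) / d_model)
)
pe_positive = torch.zeros(T, d_model)
pe_negative = torch.zeros(T, d_model)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

# Flip the positive part so positions run from (T - 1) down to -(T - 1),
# then drop the duplicated position 0 from the negative part.
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
assert pe.shape == (1, 2 * T - 1, d_model)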