mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Replace the 1st of the ConvolutionModules with NonlinAttentionModule
This commit is contained in:
parent
eb6e2b5a1d
commit
47f42ef5db
@ -343,10 +343,12 @@ class ZipformerEncoderLayer(nn.Module):
|
||||
feedforward_dim,
|
||||
dropout)
|
||||
|
||||
self.conv_module1 = ConvolutionModule(d_model,
|
||||
cnn_module_kernel)
|
||||
#self.conv_module1 = ConvolutionModule(d_model,
|
||||
#cnn_module_kernel)
|
||||
self.nonlin_attention_module = NonlinAttentionModule(d_model)
|
||||
|
||||
self.conv_module2 = ConvolutionModule(d_model,
|
||||
|
||||
self.conv_module = ConvolutionModule(d_model,
|
||||
cnn_module_kernel)
|
||||
|
||||
|
||||
@ -444,27 +446,29 @@ class ZipformerEncoderLayer(nn.Module):
|
||||
src = src + src_att
|
||||
|
||||
# convolution module
|
||||
if torch.jit.is_scripting() or random.random() > dynamic_dropout:
|
||||
src = src + self.conv_module1(src, src_key_padding_mask=src_key_padding_mask)
|
||||
if torch.jit.is_scripting() or use_self_attn:
|
||||
src = src + self.nonlin_attention_module(src,
|
||||
attn_weights,
|
||||
head_idx=0)
|
||||
|
||||
src = src + self.feed_forward2(src)
|
||||
|
||||
# pooling module
|
||||
if torch.jit.is_scripting() or use_self_attn:
|
||||
src = src + self.squeeze_excite1(src, attn_weights, attn_weights_idx=0)
|
||||
src = src + self.squeeze_excite1(src, attn_weights, head_idx=1)
|
||||
|
||||
if torch.jit.is_scripting() or use_self_attn:
|
||||
self_attn_output2 = self.self_attn.forward2(src, attn_weights)
|
||||
src = src + self_attn_output2
|
||||
|
||||
if torch.jit.is_scripting() or random.random() > dynamic_dropout:
|
||||
src = src + self.conv_module2(src, src_key_padding_mask=src_key_padding_mask)
|
||||
src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
|
||||
|
||||
src = src + self.feed_forward3(src)
|
||||
|
||||
# pooling module
|
||||
if torch.jit.is_scripting() or use_self_attn:
|
||||
src = src + self.squeeze_excite2(src, attn_weights, attn_weights_idx=1)
|
||||
src = src + self.squeeze_excite2(src, attn_weights, head_idx=2)
|
||||
|
||||
|
||||
src = self.norm_final(self.balancer(src))
|
||||
@ -1495,19 +1499,20 @@ class ModifiedSEModule(nn.Module):
|
||||
def forward(self,
|
||||
x: Tensor,
|
||||
attn_weights: Tensor,
|
||||
attn_weights_idx: int):
|
||||
head_idx: int):
|
||||
"""
|
||||
Args:
|
||||
x: a Tensor of shape (T, N, C)
|
||||
attn_weights: a Tensor of shape (N * num_heads, seq_len, seq_len), we will only use the 1st head.
|
||||
attn_weights_idx: indicates which head to choose from attn_weights
|
||||
attn_weights: a Tensor of shape (N * num_heads, seq_len, seq_len), we will only use the head indexed
|
||||
`attn_weights_index`
|
||||
head_idx: indicates which head to choose from attn_weights
|
||||
Returns:
|
||||
a Tensor of shape (T, N, C)
|
||||
"""
|
||||
(T, N, d_model) = x.shape
|
||||
num_heads = attn_weights.shape[0] // N
|
||||
attn_weights = attn_weights.reshape(N, num_heads, T, T)
|
||||
attn_weights = attn_weights[:,attn_weights_idx] # (N, T, T)
|
||||
attn_weights = attn_weights[:,head_idx] # (N, T, T)
|
||||
|
||||
bottleneck = self.to_bottleneck_proj(x) # (T, N, C)
|
||||
bottleneck = bottleneck.transpose(0, 1) # (N, T, bottleneck_dim)
|
||||
@ -1552,6 +1557,80 @@ class FeedforwardModule(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class NonlinAttentionModule(nn.Module):
|
||||
"""This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed
|
||||
from the attention module) in palce of actual convolution.
|
||||
|
||||
Args:
|
||||
channels (int): The number of channels of conv layers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, channels: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# to_scale and to_value are analogous to pointwise_conv1 in ConvolutionModule
|
||||
# we make them separate because we need an extra degree of freedom for the
|
||||
# scale, as the attention weights are constrained to sum to one so cannot
|
||||
# provide the degree of freedom for the scale of the features before
|
||||
# self.activation().
|
||||
self.to_scale = nn.Linear(channels, channels, bias=True)
|
||||
self.to_value = nn.Linear(channels, channels, bias=True)
|
||||
|
||||
|
||||
# deriv_balancer corresponds to deriv_balancer2 in ConvolutionMOdule
|
||||
self.deriv_balancer = ActivationBalancer(
|
||||
channels, channel_dim=1,
|
||||
min_positive=0.05, max_positive=1.0,
|
||||
max_abs=20.0,
|
||||
)
|
||||
|
||||
self.activation = DoubleSwish()
|
||||
|
||||
self.out_proj = ScaledLinear(channels, channels,
|
||||
bias=True,
|
||||
initial_scale=0.05)
|
||||
|
||||
def forward(self,
|
||||
x: Tensor,
|
||||
attn_weights: Tensor,
|
||||
head_idx: int,
|
||||
) -> Tensor:
|
||||
""".
|
||||
Args:
|
||||
x: a Tensor of shape (T, N, C), i.e. (time, batch, channels)
|
||||
attn_weights: a Tensor of shape (N * num_heads, seq_len, seq_len), we will only use the 1st head.
|
||||
head_idx: indicates which head to choose from attn_weights
|
||||
Returns:
|
||||
a Tensor of shape (T, N, C)
|
||||
"""
|
||||
|
||||
s = self.to_scale(x)
|
||||
v = self.to_value(x)
|
||||
if self.training and random.random() < 0.02:
|
||||
# prevent the inputs to the sigmoid from getting very large (this is
|
||||
# unlikely to happen in this particular module, so giving this path
|
||||
# a very small probability).
|
||||
s = penalize_abs_values_gt(s, limit=20.0, penalty=1.0e-04)
|
||||
|
||||
# GLU mechanism
|
||||
x = s.sigmoid() * v
|
||||
|
||||
(T, N, d_model) = x.shape
|
||||
num_heads = attn_weights.shape[0] // N
|
||||
attn_weights = attn_weights.reshape(N, num_heads, T, T)
|
||||
attn_weights = attn_weights[:,head_idx] # (N, T, T)
|
||||
x = x.transpose(0, 1) # (N, T, C)
|
||||
x = torch.bmm(attn_weights, x)
|
||||
x = self.deriv_balancer(x)
|
||||
x = x.transpose(0, 1) # (T, N, C)
|
||||
x = self.activation(x)
|
||||
x = self.out_proj(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ConvolutionModule(nn.Module):
|
||||
"""ConvolutionModule in Zipformer model.
|
||||
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user