Mirror of https://github.com/k2-fsa/icefall.git
Have 2 squeeze-excite modules per layer, using different attention heads.
commit eb6e2b5a1d
parent efbe20694f
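The change splits the layer's single attention-driven squeeze-excite module into two, each reading a different head of the shared attention weights. As a quick orientation before the diff, the toy snippet below (illustration only, with made-up sizes; not icefall code) shows how an integer index selects one head out of a flattened (N * num_heads, T, T) weight tensor, which is what the new attn_weights_idx argument does.

    # Toy illustration only -- not icefall code.
    import torch

    N, num_heads, T = 2, 4, 5        # made-up sizes
    attn_weights = torch.softmax(torch.randn(N * num_heads, T, T), dim=-1)

    per_head = attn_weights.reshape(N, num_heads, T, T)
    head0 = per_head[:, 0]           # (N, T, T), what squeeze_excite1 uses
    head1 = per_head[:, 1]           # (N, T, T), what squeeze_excite2 uses
    assert head0.shape == head1.shape == (N, T, T)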
@@ -350,7 +350,8 @@ class ZipformerEncoderLayer(nn.Module):
                                                  cnn_module_kernel)
 
 
-        self.squeeze_excite = ModifiedSEModule(d_model)
+        self.squeeze_excite1 = ModifiedSEModule(d_model)
+        self.squeeze_excite2 = ModifiedSEModule(d_model)
 
         self.norm_final = BasicNorm(d_model)
 
@@ -448,6 +449,10 @@ class ZipformerEncoderLayer(nn.Module):
 
         src = src + self.feed_forward2(src)
 
+        # pooling module
+        if torch.jit.is_scripting() or use_self_attn:
+            src = src + self.squeeze_excite1(src, attn_weights, attn_weights_idx=0)
+
         if torch.jit.is_scripting() or use_self_attn:
             self_attn_output2 = self.self_attn.forward2(src, attn_weights)
             src = src + self_attn_output2
@@ -459,8 +464,7 @@ class ZipformerEncoderLayer(nn.Module):
 
         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.squeeze_excite(src,
-                                            attn_weights)
+            src = src + self.squeeze_excite2(src, attn_weights, attn_weights_idx=1)
 
 
         src = self.norm_final(self.balancer(src))
@@ -1490,19 +1494,20 @@ class ModifiedSEModule(nn.Module):
 
     def forward(self,
                 x: Tensor,
-                attn_weights: Tensor):
+                attn_weights: Tensor,
+                attn_weights_idx: int):
         """
         Args:
           x: a Tensor of shape (T, N, C)
           attn_weights: a Tensor of shape (N * num_heads, seq_len, seq_len), we will only use the 1st head.
+          attn_weights_idx: indicates which head to choose from attn_weights
         Returns:
           a Tensor of shape (T, N, C)
         """
         (T, N, d_model) = x.shape
         num_heads = attn_weights.shape[0] // N
         attn_weights = attn_weights.reshape(N, num_heads, T, T)
-        attn_weights = attn_weights[:,0]  # (N, T, T)
+        attn_weights = attn_weights[:,attn_weights_idx]  # (N, T, T)
 
         bottleneck = self.to_bottleneck_proj(x)  # (T, N, C)
         bottleneck = bottleneck.transpose(0, 1)  # (N, T, bottleneck_dim)
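The hunk above cuts off ModifiedSEModule.forward right after the bottleneck projection, so the rest of the module is not visible in this diff. The sketch below is a minimal, self-contained approximation of an attention-weighted squeeze-excite block in the same spirit, useful only for orientation: the matmul pooling, the from_bottleneck_proj / bottleneck_dim names, the sigmoid gating, and the returned term are assumptions, not the actual icefall implementation.

    # Minimal sketch, NOT the icefall ModifiedSEModule: pools a bottleneck
    # representation with one head's attention weights, then gates the input.
    import torch
    import torch.nn as nn
    from torch import Tensor


    class AttentionSESketch(nn.Module):
        def __init__(self, d_model: int, bottleneck_dim: int = 128):
            super().__init__()
            self.to_bottleneck_proj = nn.Linear(d_model, bottleneck_dim)
            self.from_bottleneck_proj = nn.Linear(bottleneck_dim, d_model)  # assumed name

        def forward(self, x: Tensor, attn_weights: Tensor,
                    attn_weights_idx: int) -> Tensor:
            # x: (T, N, C); attn_weights: (N * num_heads, T, T)
            (T, N, d_model) = x.shape
            num_heads = attn_weights.shape[0] // N
            attn_weights = attn_weights.reshape(N, num_heads, T, T)
            attn_weights = attn_weights[:, attn_weights_idx]     # (N, T, T)

            bottleneck = self.to_bottleneck_proj(x)              # (T, N, B)
            bottleneck = bottleneck.transpose(0, 1)              # (N, T, B)

            # Assumed pooling: row t of the chosen head's attention weights
            # mixes the bottleneck features over time for frame t.
            pooled = torch.matmul(attn_weights, bottleneck)      # (N, T, B)
            pooled = pooled.transpose(0, 1)                      # (T, N, B)

            # Assumed squeeze-excite style gating of the input.
            scales = torch.sigmoid(self.from_bottleneck_proj(pooled))  # (T, N, C)
            return x * scales

Whatever the real module returns is added back to src in the encoder layer (src = src + self.squeeze_excite1(...)), so it acts as a residual correction driven by the selected attention head.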