Use the attention weights as input for the ModifiedSEModule
parent 0d94783e76
commit efbe20694f
@@ -460,7 +460,7 @@ class ZipformerEncoderLayer(nn.Module):
         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
             src = src + self.squeeze_excite(src,
-                                            self_attn_output2[...,:self.squeeze_excite.bottleneck_dim])
+                                            attn_weights)
 
 
         src = self.norm_final(self.balancer(src))
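The call site now forwards the raw per-head attention weights instead of a slice of the second self-attention output. A minimal sketch of the shape contract this assumes (the names `seq_len`, `batch`, `num_heads` are illustrative; this assumes icefall's attention module returns per-head weights rather than PyTorch's default head-averaged ones, as the docstring in the next hunk indicates):

import torch

seq_len, batch, num_heads, d_model = 10, 2, 4, 16
src = torch.randn(seq_len, batch, d_model)

# Per-head attention weights: one (seq_len, seq_len) matrix per
# (batch, head) pair, each row softmax-normalized.
attn_weights = torch.softmax(
    torch.randn(batch * num_heads, seq_len, seq_len), dim=-1
)

# Shape contract assumed by the new squeeze_excite call:
assert attn_weights.shape == (batch * num_heads, seq_len, seq_len)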
@@ -1458,6 +1458,10 @@ class ModifiedSEModule(nn.Module):
         self.in_proj = nn.Linear(d_model, d_model,
                                  bias=False)
 
+        self.to_bottleneck_proj = ScaledLinear(d_model,
+                                               bottleneck_dim,
+                                               bias=False)
+
         # Caution: this cannot work correctly with an extremely small batch size, e.g. if
         # we were training with a single very long audio sequence, or just 2 or 3 sequences
         # at a time. We make max_factor small to reduce the harm this could cause
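`ScaledLinear` is icefall's variant of `nn.Linear` (from scaling.py) that keeps a learned log-scale multiplying the weights, so a projection can start near zero without shrinking the parameter tensor itself. A rough stand-in sketching the idea, not the actual class:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledLinearSketch(nn.Linear):
    """Rough stand-in for icefall's ScaledLinear: an nn.Linear whose
    effective weight is weight * exp(weight_scale), scale learned."""

    def __init__(self, *args, initial_scale: float = 1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight_scale = nn.Parameter(
            torch.tensor(math.log(initial_scale)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight * self.weight_scale.exp(), self.bias)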
@@ -1484,21 +1488,30 @@ class ModifiedSEModule(nn.Module):
                                    grad_scale=0.01)
 
 
 
 
     def forward(self,
                 x: Tensor,
-                bottleneck: Tensor):
+                attn_weights: Tensor):
         """
         Args:
           x: a Tensor of shape (T, N, C)
-          bottleneck: a Tensor of shape (1, N, bottleneck_dim) or (T, N, bottleneck_dim) that has
-             undergone some form of aggregation over time, e.g. attention.
+          attn_weights: a Tensor of shape (N * num_heads, seq_len, seq_len); we will only use the 1st head.
         Returns:
-          a Tensor of shape (1, N, C)
+          a Tensor of shape (T, N, C)
         """
+        (T, N, d_model) = x.shape
+        num_heads = attn_weights.shape[0] // N
+        attn_weights = attn_weights.reshape(N, num_heads, T, T)
+        attn_weights = attn_weights[:, 0]  # (N, T, T)
+
+        bottleneck = self.to_bottleneck_proj(x)  # (T, N, bottleneck_dim)
+        bottleneck = bottleneck.transpose(0, 1)  # (N, T, bottleneck_dim)
+
+        # (N, T, T) x (N, T, bottleneck_dim) -> (N, T, bottleneck_dim)
+        bottleneck = torch.bmm(attn_weights, bottleneck)
         bottleneck = self.balancer(bottleneck)
         bottleneck = self.activation(bottleneck)
+        bottleneck = bottleneck.transpose(0, 1)  # (T, N, bottleneck_dim)
         scales = self.from_bottleneck_proj(bottleneck)
 
 
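Putting the new forward together: instead of receiving a pre-pooled bottleneck, the module now builds the bottleneck itself and aggregates it over time with the first attention head, so every output frame gets its own attention-weighted summary rather than one global one. A self-contained sketch of just the tensor flow (plain `nn.Linear` and `tanh` stand in for icefall's ScaledLinear, balancer, and activation, so only the shapes and the bmm step match the real module):

import torch
import torch.nn as nn
from torch import Tensor


class SEPoolSketch(nn.Module):
    # Hypothetical stand-in for ModifiedSEModule; the real module uses
    # ScaledLinear, a balancer, and a different activation.
    def __init__(self, d_model: int, bottleneck_dim: int):
        super().__init__()
        self.to_bottleneck_proj = nn.Linear(d_model, bottleneck_dim, bias=False)
        self.from_bottleneck_proj = nn.Linear(bottleneck_dim, d_model, bias=False)
        self.activation = nn.Tanh()

    def forward(self, x: Tensor, attn_weights: Tensor) -> Tensor:
        # x: (T, N, d_model); attn_weights: (N * num_heads, T, T)
        (T, N, d_model) = x.shape
        num_heads = attn_weights.shape[0] // N
        attn_weights = attn_weights.reshape(N, num_heads, T, T)[:, 0]  # (N, T, T)

        bottleneck = self.to_bottleneck_proj(x)       # (T, N, bottleneck_dim)
        bottleneck = bottleneck.transpose(0, 1)       # (N, T, bottleneck_dim)
        # Each output frame is a weighted average of bottleneck frames,
        # weighted by the 1st attention head's (row-normalized) weights:
        bottleneck = torch.bmm(attn_weights, bottleneck)  # (N, T, bottleneck_dim)
        bottleneck = self.activation(bottleneck)
        bottleneck = bottleneck.transpose(0, 1)       # (T, N, bottleneck_dim)
        return self.from_bottleneck_proj(bottleneck)  # (T, N, d_model)


T, N, num_heads, d_model, bottleneck_dim = 10, 2, 4, 16, 8
x = torch.randn(T, N, d_model)
attn_weights = torch.softmax(torch.randn(N * num_heads, T, T), dim=-1)
scales = SEPoolSketch(d_model, bottleneck_dim)(x, attn_weights)
print(scales.shape)  # torch.Size([10, 2, 16])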