Instead of a pooling operation, use the first bottleneck_dim dimensions of the output of the preceding self_attn.forward2 call as the input to the squeeze-excite module.
parent c27ee8cfcf
commit 0d94783e76
@@ -449,7 +449,8 @@ class ZipformerEncoderLayer(nn.Module):
         src = src + self.feed_forward2(src)

         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.self_attn.forward2(src, attn_weights)
+            self_attn_output2 = self.self_attn.forward2(src, attn_weights)
+            src = src + self_attn_output2

         if torch.jit.is_scripting() or random.random() > dynamic_dropout:
             src = src + self.conv_module2(src, src_key_padding_mask=src_key_padding_mask)
@@ -457,9 +458,10 @@ class ZipformerEncoderLayer(nn.Module):
         src = src + self.feed_forward3(src)

         # pooling module
-        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+        if torch.jit.is_scripting() or use_self_attn:
             src = src + self.squeeze_excite(src,
-                                            key_padding_mask=src_key_padding_mask)
+                                            self_attn_output2[...,:self.squeeze_excite.bottleneck_dim])


         src = self.norm_final(self.balancer(src))

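Note that the squeeze-excite branch is now guarded by use_self_attn rather than by the dynamic-dropout test: self_attn_output2 only exists when the second self-attention branch ran, so the two branches have to be kept or dropped together. Below is a minimal runnable sketch of the new wiring (not part of the commit); value_proj, from_bottleneck_proj and the sigmoid gate are toy stand-ins for the real Zipformer self_attn.forward2 and squeeze-excite modules, and shapes follow this file's (T, N, C) convention.

    import torch
    import torch.nn as nn

    d_model, bottleneck_dim = 32, 16
    T, N = 10, 2

    # Toy stand-in for self_attn.forward2: apply precomputed attention weights
    # to a projection of src.
    value_proj = nn.Linear(d_model, d_model)

    def self_attn_forward2(src, attn_weights):
        # src: (T, N, C), attn_weights: (N, T, T) -> (T, N, C)
        return torch.einsum('nts,snc->tnc', attn_weights, value_proj(src))

    # Toy stand-in for the squeeze-excite module after this commit: it consumes
    # a caller-supplied bottleneck tensor instead of pooling over time itself.
    from_bottleneck_proj = nn.Linear(bottleneck_dim, d_model)

    def squeeze_excite(x, bottleneck):
        scales = torch.sigmoid(from_bottleneck_proj(bottleneck))  # (T, N, C)
        return x * scales

    src = torch.randn(T, N, d_model)
    attn_weights = torch.softmax(torch.randn(N, T, T), dim=-1)

    use_self_attn = True
    if use_self_attn:
        self_attn_output2 = self_attn_forward2(src, attn_weights)
        src = src + self_attn_output2
        # New in this commit: the first bottleneck_dim channels of the attention
        # output become the squeeze input.
        src = src + squeeze_excite(src, self_attn_output2[..., :bottleneck_dim])
    print(src.shape)  # torch.Size([10, 2, 32])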
@@ -1451,9 +1453,7 @@ class ModifiedSEModule(nn.Module):
                  d_model: int,
                  bottleneck_dim: int = 16):
         super().__init__()
-        self.squeeze_proj = nn.Linear(d_model, bottleneck_dim,
-                                      bias=False)
-
+        self.bottleneck_dim = bottleneck_dim

         self.in_proj = nn.Linear(d_model, d_model,
                                  bias=False)
@@ -1488,33 +1488,22 @@ class ModifiedSEModule(nn.Module):

     def forward(self,
                 x: Tensor,
-                key_padding_mask):
+                bottleneck: Tensor):
         """
         Args:
            x: a Tensor of shape (T, N, C)
-           key_padding_mask: a Tensor of bool, of shape (N, T), with True in masked
-             positions.
+           bottleneck: a Tensor of shape (1, N, bottleneck_dim) or (T, N, bottleneck_dim) that has
+             undergone some form of aggregation over time, e.g. attention.
         Returns:
            a Tensor of shape (1, N, C)
         """
-        if key_padding_mask is not None:
-            pooling_mask = key_padding_mask.logical_not().to(x.dtype)  # (N, T)
-            pooling_mask = (pooling_mask / pooling_mask.sum(dim=1, keepdim=True))
-            pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1)
-            # now pooling_mask: (T, N, 1)
-        else:
-            num_frames = x.shape[0]
-            pooling_mask = 1.0 / num_frames
-
-        squeezed = (x * pooling_mask).sum(dim=0, keepdim=True)
-        squeezed = self.squeeze_proj(squeezed)
-        squeezed = self.balancer(squeezed)
-        squeezed = self.activation(squeezed)
-        squeezed = self.from_bottleneck_proj(squeezed)
+        bottleneck = self.balancer(bottleneck)
+        bottleneck = self.activation(bottleneck)
+        scales = self.from_bottleneck_proj(bottleneck)

         x = self.in_proj(x)
-        x = x * squeezed
+        x = x * scales
         return self.out_whiten(self.out_proj(x))

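For reference, a self-contained sketch of the module after this change (not part of the commit). Only the identifiers visible in the diff are taken from the source; the balancer and out_whiten are replaced by identity layers, the activation by a sigmoid, and the projections by plain nn.Linear layers so the sketch runs on its own, whereas the real module uses the custom layers defined elsewhere in this file.

    import torch
    import torch.nn as nn
    from torch import Tensor

    class ModifiedSEModuleSketch(nn.Module):
        """Squeeze-and-excite that takes its bottleneck ("squeeze") input from
        the caller instead of pooling over time itself."""
        def __init__(self, d_model: int, bottleneck_dim: int = 16):
            super().__init__()
            self.bottleneck_dim = bottleneck_dim
            self.in_proj = nn.Linear(d_model, d_model, bias=False)
            self.from_bottleneck_proj = nn.Linear(bottleneck_dim, d_model)
            self.out_proj = nn.Linear(d_model, d_model, bias=False)
            # Stand-ins: the real module uses its own Balancer / activation /
            # Whiten modules from this file.
            self.balancer = nn.Identity()
            self.activation = nn.Sigmoid()
            self.out_whiten = nn.Identity()

        def forward(self, x: Tensor, bottleneck: Tensor) -> Tensor:
            """
            Args:
              x: (T, N, C)
              bottleneck: (1, N, bottleneck_dim) or (T, N, bottleneck_dim),
                already aggregated over time (e.g. by attention).
            Returns:
              a Tensor of shape (T, N, C); with a (1, N, bottleneck_dim)
              bottleneck the scales broadcast over the time dimension.
            """
            bottleneck = self.balancer(bottleneck)
            bottleneck = self.activation(bottleneck)
            scales = self.from_bottleneck_proj(bottleneck)  # (1 or T, N, C)
            x = self.in_proj(x)
            x = x * scales
            return self.out_whiten(self.out_proj(x))

    # Usage mirroring the encoder layer: slice the attention output down to
    # the module's bottleneck_dim.
    T, N, C = 10, 2, 32
    se = ModifiedSEModuleSketch(C)
    src = torch.randn(T, N, C)
    self_attn_output2 = torch.randn(T, N, C)
    out = se(src, self_attn_output2[..., :se.bottleneck_dim])
    print(out.shape)  # torch.Size([10, 2, 32])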