mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Implement pooling module, add it after initial feedforward.
This commit is contained in:
parent
730e6c8914
commit
5fda800b6d
@ -330,6 +330,8 @@ class ZipformerEncoderLayer(nn.Module):
|
||||
d_model, attention_dim, nhead, pos_dim, dropout=0.0,
|
||||
)
|
||||
|
||||
self.pooling = PoolingModule(d_model)
|
||||
|
||||
self.feed_forward1 = FeedforwardModule(d_model,
|
||||
feedforward_dim,
|
||||
dropout)
|
||||
@ -410,6 +412,10 @@ class ZipformerEncoderLayer(nn.Module):
|
||||
# macaron style feed forward module
|
||||
src = src + self.feed_forward1(src)
|
||||
|
||||
# pooling module
|
||||
src = src + self.pooling(src,
|
||||
key_padding_mask=src_key_padding_mask)
|
||||
|
||||
# multi-headed self-attention module
|
||||
src_att, attn_weights = self.self_attn(
|
||||
src,
|
||||
@ -1384,6 +1390,43 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
logging.info(f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}")
|
||||
|
||||
|
||||
|
||||
|
||||
class PoolingModule(nn.Module):
    """
    Averages the input over the time dimension and projects the result with a
    square (d_model x d_model) matrix that has no bias.
    """
    def __init__(self,
                 d_model: int):
        super().__init__()
        self.proj = ScaledLinear(d_model, d_model,
                                 initial_scale=0.1, bias=False)

    def forward(self,
                x: Tensor,
                key_padding_mask) -> Tensor:
        """
        Args:
           x: a Tensor of shape (T, N, C)
          key_padding_mask: a Tensor of bool, of shape (N, T), with True in masked
            positions; or None, in which case all T frames are averaged.
        Returns:
           a Tensor of shape (1, N, C)
        """
        if key_padding_mask is not None:
            # 1.0 for real frames, 0.0 for padded frames.  The dtype must come
            # from `x`, the input of this method (there is no `src` here).
            pooling_mask = key_padding_mask.logical_not().to(x.dtype)  # (N, T)
            # Number of real frames per sequence.  Clamp to at least 1 so that a
            # sequence with every frame masked pools to zeros instead of
            # producing NaN from 0 / 0.
            num_frames = pooling_mask.sum(dim=1, keepdim=True).clamp(min=1.0)
            pooling_mask = pooling_mask / num_frames
            pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1)
            # now pooling_mask: (T, N, 1)
        else:
            # No padding: plain mean over the T frames.
            num_frames = x.shape[0]
            pooling_mask = 1.0 / num_frames

        # Weighted sum over time == mean over the non-masked frames.
        x = (x * pooling_mask).sum(dim=0, keepdim=True)
        x = self.proj(x)
        return x
|
||||
|
||||
|
||||
class FeedforwardModule(nn.Module):
|
||||
"""Feedforward module in Zipformer model.
|
||||
"""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user