Remove query in AttentionDownsample, rename to SimpleDownsample.

Daniel Povey 2022-12-17 13:44:08 +08:00
parent 35b63c1387
commit ed7e01448c


@@ -798,7 +798,7 @@ class AttentionDownsample(torch.nn.Module):
         Require out_channels > in_channels.
         """
-        super(AttentionDownsample, self).__init__()
-        self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
+        super(SimpleDownsample, self).__init__()
+
         self.bias = nn.Parameter(torch.zeros(downsample))
         self.name = None  # will be set from training code
@@ -833,24 +833,10 @@ class AttentionDownsample(torch.nn.Module):
             assert src.shape[0] == d_seq_len * ds
         src = src.reshape(d_seq_len, ds, batch_size, in_channels)
-        # scores: (d_seq_len, downsample, batch_size, 1)
-        scores = (src * self.query).sum(dim=-1, keepdim=True)
-        scores = scores + self.bias.unsqueeze(-1).unsqueeze(-1)
-        scores = penalize_abs_values_gt(scores,
-                                        limit=20.0,
-                                        penalty=1.0e-04,
-                                        name=self.name)
-        dropout = float(self.dropout)
-        if dropout > 0.0:
-            # the 0:1, done on the axis of size 'downsample', selects just
-            # one dimension while keeping the dim.  We'll then broadcast when
-            # we multiply.
-            dropout_mask = torch.rand_like(scores[:, 0:1]) > dropout
-            scores = scores * dropout_mask
-        weights = scores.softmax(dim=1)
+        weights = self.bias.softmax(dim=0)
+        # weights: (downsample, 1, 1)
+        weights = weights.unsqueeze(-1).unsqueeze(-1)
         # ans1 is the first `in_channels` channels of the output
         ans = (src * weights).sum(dim=1)
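
After this change the downsampling weights no longer depend on the input: the per-channel query, the penalize_abs_values_gt() guard on the scores, and the score-level dropout all go away, leaving only a softmax over a single learned bias of size `downsample`. Below is a minimal, self-contained sketch of what the logic reduces to; the class name and the omission of sequence padding and of the in_channels -> out_channels projection are simplifications for illustration, not the actual class from this commit.

import torch
import torch.nn as nn

class SimpleDownsampleSketch(nn.Module):
    """Downsamples (seq_len, batch, channels) by a fixed weighted average."""

    def __init__(self, downsample: int):
        super().__init__()
        # One learned scalar per position within each group of `downsample`
        # frames; this replaces the removed per-channel attention query.
        self.bias = nn.Parameter(torch.zeros(downsample))

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        # src: (seq_len, batch_size, in_channels).  This sketch assumes
        # seq_len is an exact multiple of `downsample`; the real code pads
        # the sequence to such a multiple first, as the assert in the
        # diff above suggests.
        seq_len, batch_size, in_channels = src.shape
        ds = self.bias.numel()
        d_seq_len = seq_len // ds
        src = src.reshape(d_seq_len, ds, batch_size, in_channels)
        # Data-independent weights over the ds frames of each group,
        # shaped (ds, 1, 1) so they broadcast over batch and channels.
        weights = self.bias.softmax(dim=0).unsqueeze(-1).unsqueeze(-1)
        # Weighted average over the downsample axis:
        # result is (d_seq_len, batch_size, in_channels).
        return (src * weights).sum(dim=1)

# e.g. m = SimpleDownsampleSketch(downsample=2)
#      m(torch.randn(8, 2, 4)).shape  ->  torch.Size([4, 2, 4])

Since the weights are now identical for every frame group and every batch element, the dropout mask on the scores and the absolute-value penalty serve no purpose, which is presumably why they are deleted along with the query.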