Remove bias from SimpleUpsample, add one to AttentionDownsample

Daniel Povey 2022-12-16 17:59:15 +08:00
parent bc002a9eda
commit 4eb3e97848


@@ -799,6 +799,7 @@ class AttentionDownsample(torch.nn.Module):
         """
         super(AttentionDownsample, self).__init__()
         self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
+        self.bias = nn.Parameter(torch.zeros(downsample))
         self.name = None  # will be set from training code
         self.dropout = copy.deepcopy(dropout)
@@ -832,8 +833,9 @@ class AttentionDownsample(torch.nn.Module):
         assert src.shape[0] == d_seq_len * ds
         src = src.reshape(d_seq_len, ds, batch_size, in_channels)
-        # scores: (d_seq_len, downsample, batch_size)
+        # scores: (d_seq_len, downsample, batch_size, 1)
         scores = (src * self.query).sum(dim=-1, keepdim=True)
+        scores = scores + self.bias.unsqueeze(-1).unsqueeze(-1)
         scores = penalize_abs_values_gt(scores,
                                         limit=20.0,
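For reference, the new per-offset bias broadcasts cleanly against the 4-D scores tensor: self.bias has shape (downsample,), and the two unsqueeze(-1) calls turn it into (downsample, 1, 1), which lines up with the last three dimensions of scores, i.e. (d_seq_len, downsample, batch_size, 1). A minimal standalone sketch of that broadcast (the dimension sizes below are illustrative, not values from the repo):

    import torch

    d_seq_len, downsample, batch_size, in_channels = 5, 2, 3, 8  # illustrative sizes only

    src = torch.randn(d_seq_len, downsample, batch_size, in_channels)
    query = torch.randn(in_channels) * (in_channels ** -0.5)
    bias = torch.zeros(downsample)  # new learnable bias, zero-initialized as in the diff

    # scores: (d_seq_len, downsample, batch_size, 1)
    scores = (src * query).sum(dim=-1, keepdim=True)
    # bias.unsqueeze(-1).unsqueeze(-1) -> (downsample, 1, 1); broadcasts over the last 3 dims
    scores = scores + bias.unsqueeze(-1).unsqueeze(-1)

    assert scores.shape == (d_seq_len, downsample, batch_size, 1)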
@@ -869,7 +871,7 @@ class SimpleUpsample(torch.nn.Module):
                  num_channels: int,
                  upsample: int):
         super(SimpleUpsample, self).__init__()
-        self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01)
+        self.upsample = upsample
     def forward(self,
                 src: Tensor) -> Tensor:
@@ -878,10 +880,9 @@ class SimpleUpsample(torch.nn.Module):
         Returns a tensor of shape
             ( (seq_len*upsample), batch_size, num_channels)
         """
-        upsample = self.bias.shape[0]
+        upsample = self.upsample
         (seq_len, batch_size, num_channels) = src.shape
         src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
-        src = src + self.bias.unsqueeze(1)
         src = src.reshape(seq_len * upsample, batch_size, num_channels)
         return src
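With the bias removed, SimpleUpsample no longer has any learnable parameters: it simply repeats each frame upsample times along the time axis, while any per-offset modelling now lives in the bias added to AttentionDownsample above. A minimal self-contained sketch of the module as it stands after this commit (the class name and the sizes in the usage lines are illustrative, not from the repo):

    import torch
    from torch import Tensor, nn

    class SimpleUpsampleSketch(nn.Module):
        """Sketch of SimpleUpsample after this commit: no learnable bias,
        just repeats each frame `upsample` times along the time axis."""
        def __init__(self, num_channels: int, upsample: int):
            super().__init__()
            self.upsample = upsample

        def forward(self, src: Tensor) -> Tensor:
            # src: (seq_len, batch_size, num_channels)
            upsample = self.upsample
            (seq_len, batch_size, num_channels) = src.shape
            src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
            return src.reshape(seq_len * upsample, batch_size, num_channels)

    # usage, with illustrative sizes:
    m = SimpleUpsampleSketch(num_channels=4, upsample=2)
    out = m(torch.randn(10, 3, 4))
    assert out.shape == (20, 3, 4)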