mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Remove bias from SimpleUpsample, add one to AttentionDownsample
This commit is contained in:
parent
bc002a9eda
commit
4eb3e97848
@ -799,6 +799,7 @@ class AttentionDownsample(torch.nn.Module):
|
||||
"""
|
||||
super(AttentionDownsample, self).__init__()
|
||||
self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
|
||||
self.bias = nn.Parameter(torch.zeros(downsample))
|
||||
|
||||
self.name = None # will be set from training code
|
||||
self.dropout = copy.deepcopy(dropout)
|
||||
@ -832,8 +833,9 @@ class AttentionDownsample(torch.nn.Module):
|
||||
assert src.shape[0] == d_seq_len * ds
|
||||
|
||||
src = src.reshape(d_seq_len, ds, batch_size, in_channels)
|
||||
# scores: (d_seq_len, downsample, batch_size)
|
||||
# scores: (d_seq_len, downsample, batch_size, 1)
|
||||
scores = (src * self.query).sum(dim=-1, keepdim=True)
|
||||
scores = scores + self.bias.unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
scores = penalize_abs_values_gt(scores,
|
||||
limit=20.0,
|
||||
@ -869,7 +871,7 @@ class SimpleUpsample(torch.nn.Module):
|
||||
num_channels: int,
|
||||
upsample: int):
|
||||
super(SimpleUpsample, self).__init__()
|
||||
self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01)
|
||||
self.upsample = upsample
|
||||
|
||||
def forward(self,
|
||||
src: Tensor) -> Tensor:
|
||||
@ -878,10 +880,9 @@ class SimpleUpsample(torch.nn.Module):
|
||||
Returns a tensor of shape
|
||||
( (seq_len*upsample), batch_size, num_channels)
|
||||
"""
|
||||
upsample = self.bias.shape[0]
|
||||
upsample = self.upsample
|
||||
(seq_len, batch_size, num_channels) = src.shape
|
||||
src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
|
||||
src = src + self.bias.unsqueeze(1)
|
||||
src = src.reshape(seq_len * upsample, batch_size, num_channels)
|
||||
return src
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user