Simplify downsampling and upsampling

Daniel Povey 2022-09-28 13:49:11 +08:00
parent 01af88c2f6
commit 10a3061025


@@ -108,14 +108,12 @@ class Conformer(EncoderInterface):
                 dropout=dropout
             ),
             input_dim=d_model[0],
-            module_dim=d_model[1],
             output_dim=d_model[1],
             downsample=conformer_subsampling_factor,
         )
-        self.out_proj = ScaledLinear(
-            d_model[0] + d_model[1], d_model[2],
-            bias=False)
+        self.out_combiner = SimpleCombiner(d_model[0],
+                                           d_model[1])
 
     def forward(
@@ -158,12 +156,10 @@ class Conformer(EncoderInterface):
             x1_no_combine, src_key_padding_mask=mask, warmup=warmup
         )  # (T, N, C) where C == d_model[1]
-        x = torch.cat((x1, x2), dim=2)
+        x = self.out_combiner(x1, x2)
         x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
-        x = self.out_proj(x)
 
         return x, lengths
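
For clarity, here is a rough sketch of what this hunk does to the combination of the two encoder outputs. The sizes are made up for illustration, and the sigmoid gate is only a stand-in for SimpleCombiner's learned weighting (the class is added later in this commit):

    import torch

    T, N, d0, d1 = 50, 4, 256, 384          # illustrative sizes only
    x1 = torch.randn(T, N, d0)              # output of the first encoder stack
    x2 = torch.randn(T, N, d1)              # output of the downsampled encoder2

    # Old path: concatenate, then project d0 + d1 -> d_model[2] with ScaledLinear.
    old_combined = torch.cat((x1, x2), dim=2)             # (T, N, d0 + d1)

    # New path: SimpleCombiner gates the shared d0 channels of x1 and x2 and
    # passes the remaining d1 - d0 channels of x2 through, so the result stays
    # at dimension d1 and no output projection is needed.
    gate = torch.sigmoid(torch.zeros(T, N, 1))            # stand-in for the learned gate
    new_combined = torch.cat(
        (gate * x1 + (1.0 - gate) * x2[..., :d0], x2[..., d0:]), dim=2
    )                                                     # (T, N, d1)
    print(old_combined.shape, new_combined.shape)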
@@ -415,29 +411,14 @@ class DownsampledConformerEncoder(nn.Module):
     def __init__(self,
                  encoder: nn.Module,
                  input_dim: int,
-                 module_dim: int,
                  output_dim: int,
                  downsample: int):
         super(DownsampledConformerEncoder, self).__init__()
         self.downsample_factor = downsample
-        # note: we'll pad manually.
-        self.downsample = nn.Conv1d(
-            input_dim,
-            module_dim,
-            kernel_size=downsample,
-            stride=downsample,
-            padding=0)
+        self.downsample = AttentionDownsample(input_dim, output_dim, downsample)
         self.encoder = encoder
+        self.upsample = SimpleUpsample(output_dim, downsample)
-        self.upsample = nn.ConvTranspose1d(
-            module_dim,
-            output_dim,
-            kernel_size=downsample,
-            stride=downsample,
-            padding=0)
 
     def forward(self,
                 src: Tensor,
@@ -462,10 +443,56 @@ class DownsampledConformerEncoder(nn.Module):
         Returns: output of shape (S, N, F) where F is the number of output features
             (output_dim to constructor)
         """
-        (seq_len, batch_size, embedding_dim) = src.shape
+        src_orig = src
+        src = self.downsample(src)
+        ds = self.downsample_factor
+        if mask is not None:
+            mask = mask[::ds,::ds]
+        if src_key_padding_mask is not None:
+            src_key_padding_mask = src_key_padding_mask[::ds]
+        src, _src_no_combine = self.encoder(
+            src, src_key_padding_mask=mask, warmup=warmup
+        )
+        src = self.upsample(src)
+        return src
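
With the padding and channel bookkeeping pushed into AttentionDownsample and SimpleUpsample (both added below), the wrapper's forward reduces to downsample, encode, upsample. A rough usage sketch with made-up sizes, assuming the two new classes from this file are in scope and using a no-op in place of the wrapped encoder; the upsampled length is rounded up to a multiple of the factor, so a trim back to the input length (as the removed code did with src[:seq_len]) is still needed somewhere:

    import torch

    seq_len, batch, d_in, d_out, ds = 103, 4, 256, 384, 2    # made-up sizes
    src = torch.randn(seq_len, batch, d_in)

    down = AttentionDownsample(d_in, d_out, ds)               # defined below
    up = SimpleUpsample(d_out, ds)                            # defined below

    x = down(src)      # ((seq_len + ds - 1) // ds, batch, d_out) == (52, 4, 384)
    # ... the wrapped ConformerEncoder would run here, preserving the shape ...
    x = up(x)          # (52 * ds, batch, d_out) == (104, 4, 384)
    x = x[:seq_len]    # trim the padded frame, as the removed code did
    print(x.shape)     # torch.Size([103, 4, 384])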
+
+class AttentionDownsample(torch.nn.Module):
+    """
+    Does downsampling with attention, by weighted sum, and a projection
+    for the extra output channels.
+    """
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 downsample: int):
+        """
+        Requires out_channels > in_channels.
+        """
+        super(AttentionDownsample, self).__init__()
+        assert out_channels > in_channels
+        self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
+
+        # fill in the extra dimensions with a projection of the input
+        self.extra_proj = nn.Linear(in_channels * downsample,
+                                    out_channels - in_channels,
+                                    bias=False)
+        self.downsample = downsample
+
+    def forward(self,
+                src: Tensor) -> Tensor:
+        """
+        src: (seq_len, batch_size, in_channels)
+        Returns a tensor of shape
+           ((seq_len+downsample-1)//downsample, batch_size, out_channels)
+        """
+        (seq_len, batch_size, in_channels) = src.shape
+        ds = self.downsample
+        d_seq_len = (seq_len + ds - 1) // ds
+        src_orig = src
         # Pad to an exact multiple of self.downsample
         if seq_len != d_seq_len * ds:
             # right-pad src, repeating the last element.
             pad = d_seq_len * ds - seq_len
@@ -473,31 +500,83 @@ class DownsampledConformerEncoder(nn.Module):
             src = torch.cat((src, src_extra), dim=0)
             assert src.shape[0] == d_seq_len * ds
 
-        if mask is not None:
-            mask = mask[::ds,::ds]
-        if src_key_padding_mask is not None:
-            src_key_padding_mask = src_key_padding_mask[::ds]
+        src = src.reshape(d_seq_len, ds, batch_size, in_channels)
+        scores = (src * self.query).sum(dim=-1, keepdim=True)
+        weights = scores.softmax(dim=1)
-        src = src.permute(1, 2, 0)  # (#batch, channels, time).
-        src = self.downsample(src)
-        src = src.permute(2, 0, 1)  # (time, batch, channels)
+        # ans1 is the first `in_channels` channels of the output
+        ans1 = (src * weights).sum(dim=1)
+        src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels)
+        ans2 = self.extra_proj(src)
-        src, _src_no_combine = self.encoder(
-            src, src_key_padding_mask=mask, warmup=warmup
-        )  # (T, N, C)
+        ans = torch.cat((ans1, ans2), dim=2)
+        return ans
-        src = src.permute(1, 2, 0)  # (#batch, channels, time).
-        src = self.upsample(src)
-        src = src.permute(2, 0, 1)  # (time, batch, channels)
-        new_seq_len = src.shape[0]
-        assert new_seq_len >= seq_len
-        if new_seq_len > seq_len:
-            src = src[:seq_len]
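
For reference, the core of the new AttentionDownsample above, re-implemented inline with toy sizes so the intermediate shapes are visible: within each block of `downsample` frames, a learned query produces softmax weights for a weighted sum (giving the first in_channels output channels), and a linear projection of the flattened block fills the remaining channels.

    import torch

    seq_len, batch, d_in, d_out, ds = 7, 2, 8, 12, 3   # toy sizes only
    src = torch.randn(seq_len, batch, d_in)
    query = torch.randn(d_in) * d_in ** -0.5
    extra_proj = torch.nn.Linear(d_in * ds, d_out - d_in, bias=False)

    d_seq_len = (seq_len + ds - 1) // ds               # 3
    pad = d_seq_len * ds - seq_len                     # right-pad by repeating the last frame
    src = torch.cat((src, src[-1:].expand(pad, batch, d_in)), dim=0)

    blocks = src.reshape(d_seq_len, ds, batch, d_in)
    weights = (blocks * query).sum(dim=-1, keepdim=True).softmax(dim=1)
    ans1 = (blocks * weights).sum(dim=1)                                    # (3, 2, 8)
    ans2 = extra_proj(blocks.permute(0, 2, 1, 3).reshape(d_seq_len, batch, ds * d_in))
    print(torch.cat((ans1, ans2), dim=2).shape)                             # torch.Size([3, 2, 12])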
+
+class SimpleUpsample(torch.nn.Module):
+    """
+    A very simple form of upsampling that mostly just repeats the input, but
+    also adds a position-specific bias.
+    """
+    def __init__(self,
+                 num_channels: int,
+                 upsample: int):
+        super(SimpleUpsample, self).__init__()
+        self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01)
+
+    def forward(self,
+                src: Tensor) -> Tensor:
+        """
+        src: (seq_len, batch_size, num_channels)
+        Returns a tensor of shape
+           (seq_len*upsample, batch_size, num_channels)
+        """
+        upsample = self.bias.shape[0]
+        (seq_len, batch_size, num_channels) = src.shape
+        src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
+        src = src + self.bias.unsqueeze(1)
+        src = src.reshape(seq_len * upsample, batch_size, num_channels)
+        return src
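
A quick shape check for SimpleUpsample with toy sizes, assuming the class above is in scope: each frame is repeated `upsample` times and a small learned per-position bias is added, so only the time axis grows.

    import torch

    up = SimpleUpsample(num_channels=6, upsample=2)   # toy sizes
    x = torch.randn(5, 3, 6)                          # (seq_len, batch_size, num_channels)
    print(up(x).shape)                                # torch.Size([10, 3, 6])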
+
+class SimpleCombiner(torch.nn.Module):
+    """
+    A very simple way of combining 2 vectors of 2 different dims, via a
+    learned weighted combination in the shared part of the dim.
+    """
+    def __init__(self,
+                 dim1: int,
+                 dim2: int):
+        super(SimpleCombiner, self).__init__()
+        assert dim2 > dim1
+        self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01)
+        self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01)
+
+    def forward(self,
+                src1: Tensor,
+                src2: Tensor) -> Tensor:
+        """
+        src1: (*, dim1)
+        src2: (*, dim2)
+        Returns: a tensor of shape (*, dim2)
+        """
+        assert src1.shape[:-1] == src2.shape[:-1]
+        dim1 = src1.shape[-1]
+        dim2 = src2.shape[-1]
+
+        weight1 = (src1 * self.to_weight1).sum(dim=-1, keepdim=True)
+        weight2 = (src2 * self.to_weight2).sum(dim=-1, keepdim=True)
+        weight = (weight1 + weight2).sigmoid()
+
+        src2_part1 = src2[...,:dim1]
+        part1 = src1 * weight + src2_part1 * (1.0 - weight)
+        part2 = src2[...,dim1:]
+        return torch.cat((part1, part2), dim=-1)
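
And a matching check for SimpleCombiner (again toy sizes, class assumed in scope): the first dim1 channels of src2 are gated against src1 with a learned sigmoid weight, the remaining dim2 - dim1 channels pass through, and the output keeps src2's dimension, which is why the Conformer no longer needs out_proj.

    import torch

    comb = SimpleCombiner(dim1=4, dim2=6)   # toy sizes; requires dim2 > dim1
    src1 = torch.randn(10, 2, 4)
    src2 = torch.randn(10, 2, 6)
    print(comb(src1, src2).shape)           # torch.Size([10, 2, 6])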
 
 class RelPositionalEncoding(torch.nn.Module):
     """Relative positional encoding module.