Simplify downsampling and upsampling
This commit is contained in:
parent 01af88c2f6
commit 10a3061025
@@ -108,14 +108,12 @@ class Conformer(EncoderInterface):
                 dropout=dropout
             ),
             input_dim=d_model[0],
-            module_dim=d_model[1],
             output_dim=d_model[1],
             downsample=conformer_subsampling_factor,
         )

-        self.out_proj = ScaledLinear(
-            d_model[0] + d_model[1], d_model[2],
-            bias=False)
+        self.out_combiner = SimpleCombiner(d_model[0],
+                                           d_model[1])

     def forward(
@@ -158,12 +156,10 @@ class Conformer(EncoderInterface):
             x1_no_combine, src_key_padding_mask=mask, warmup=warmup
         ) # (T, N, C) where C == d_model[1]

-        x = torch.cat((x1, x2), dim=2)
+        x = self.out_combiner(x1, x2)

         x = x.permute(1, 0, 2) # (T, N, C) -> (N, T, C)

-        x = self.out_proj(x)
-
         return x, lengths
@@ -415,29 +411,14 @@ class DownsampledConformerEncoder(nn.Module):
     def __init__(self,
                  encoder: nn.Module,
                  input_dim: int,
-                 module_dim: int,
                  output_dim: int,
                  downsample: int):
         super(DownsampledConformerEncoder, self).__init__()

         self.downsample_factor = downsample

-        # note: we'll pad manually.
-        self.downsample = nn.Conv1d(
-            input_dim,
-            module_dim,
-            kernel_size=downsample,
-            stride=downsample,
-            padding=0)
-
+        self.downsample = AttentionDownsample(input_dim, output_dim, downsample)
         self.encoder = encoder
-
-        self.upsample = nn.ConvTranspose1d(
-            module_dim,
-            output_dim,
-            kernel_size=downsample,
-            stride=downsample,
-            padding=0)
+        self.upsample = SimpleUpsample(output_dim, downsample)

     def forward(self,
                 src: Tensor,
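A minimal round-trip sketch of how the two new modules compose (hypothetical dims; assumes AttentionDownsample and SimpleUpsample as added later in this diff, with torch imported as in this file):

    import torch

    down = AttentionDownsample(in_channels=256, out_channels=384, downsample=2)
    up = SimpleUpsample(num_channels=384, upsample=2)

    src = torch.randn(100, 8, 256)  # (seq_len, batch, channels)
    y = down(src)                   # (50, 8, 384): frames pooled in groups of 2
    z = up(y)                       # (100, 8, 384): each frame repeated twice, plus a learned bias
    assert y.shape == (50, 8, 384) and z.shape == (100, 8, 384)

Both new modules work directly in the (seq_len, batch, channels) layout, which is why the permutes around nn.Conv1d / nn.ConvTranspose1d can be dropped from forward().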
@@ -462,10 +443,56 @@ class DownsampledConformerEncoder(nn.Module):
         Returns: output of shape (S, N, F) where F is the number of output features
             (output_dim to constructor)
         """
-        (seq_len, batch_size, embedding_dim) = src.shape
-
-        src_orig = src
+        src = self.downsample(src)
+        ds = self.downsample_factor
+        if mask is not None:
+            mask = mask[::ds,::ds]
+        if src_key_padding_mask is not None:
+            src_key_padding_mask = src_key_padding_mask[::ds]
+
+        src, _src_no_combine = self.encoder(
+            src, src_key_padding_mask=mask, warmup=warmup
+        )
+        src = self.upsample(src)
+
+        return src
+
+
+class AttentionDownsample(torch.nn.Module):
+    """
+    Does downsampling with attention, by weighted sum, and a projection.
+    """
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 downsample: int):
+        """
+        Requires out_channels > in_channels.
+        """
+        super(AttentionDownsample, self).__init__()
+        assert out_channels > in_channels
+        self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
+
+        # fill in the extra dimensions with a projection of the input
+        self.extra_proj = nn.Linear(in_channels * downsample,
+                                    out_channels - in_channels,
+                                    bias=False)
+        self.downsample = downsample
+
+    def forward(self,
+                src: Tensor) -> Tensor:
+        """
+        src: (seq_len, batch_size, in_channels)
+        Returns a tensor of shape
+        ((seq_len+downsample-1)//downsample, batch_size, out_channels)
+        """
+        (seq_len, batch_size, in_channels) = src.shape
+        ds = self.downsample
+        d_seq_len = (seq_len + ds - 1) // ds
+
+        src_orig = src
+        # Pad to an exact multiple of self.downsample
+        if seq_len != d_seq_len * ds:
+            # right-pad src, repeating the last element.
+            pad = d_seq_len * ds - seq_len
@@ -473,31 +500,83 @@ class DownsampledConformerEncoder(nn.Module):
             src = torch.cat((src, src_extra), dim=0)
             assert src.shape[0] == d_seq_len * ds

-        if mask is not None:
-            mask = mask[::ds,::ds]
-        if src_key_padding_mask is not None:
-            src_key_padding_mask = src_key_padding_mask[::ds]
+        src = src.reshape(d_seq_len, ds, batch_size, in_channels)
+        scores = (src * self.query).sum(dim=-1, keepdim=True)
+        weights = scores.softmax(dim=1)

-        src = src.permute(1, 2, 0) # (#batch, channels, time).
-        src = self.downsample(src)
-        src = src.permute(2, 0, 1) # (time, batch, channels)
+        # ans1 is the first `in_channels` channels of the output
+        ans1 = (src * weights).sum(dim=1)
+        src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels)
+        ans2 = self.extra_proj(src)

-        src, _src_no_combine = self.encoder(
-            src, src_key_padding_mask=mask, warmup=warmup
-        ) # (T, N, C)
+        ans = torch.cat((ans1, ans2), dim=2)
+        return ans

-        src = src.permute(1, 2, 0) # (#batch, channels, time).
-        src = self.upsample(src)
-        src = src.permute(2, 0, 1) # (time, batch, channels)
-
-        new_seq_len = src.shape[0]
-        assert new_seq_len >= seq_len
-        if new_seq_len > seq_len:
-            src = src[:seq_len]
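A quick sanity sketch for the padding branch above, with a hypothetical odd seq_len (assumes the full class from this commit; note one line computing src_extra falls at the hunk boundary and is elided from the diff):

    import torch

    m = AttentionDownsample(in_channels=256, out_channels=384, downsample=2)
    x = torch.randn(101, 8, 256)    # odd length exercises the right-padding branch
    y = m(x)
    assert y.shape == (51, 8, 384)  # (101 + 2 - 1) // 2 == 51

The first 256 output channels are an attention-weighted average within each group of 2 frames; the remaining 128 come from extra_proj applied to the concatenated group.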
+class SimpleUpsample(torch.nn.Module):
+    """
+    A very simple form of upsampling that mostly just repeats the input, but
+    also adds a position-specific bias.
+    """
+    def __init__(self,
+                 num_channels: int,
+                 upsample: int):
+        super(SimpleUpsample, self).__init__()
+        self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01)
+
+    def forward(self,
+                src: Tensor) -> Tensor:
+        """
+        src: (seq_len, batch_size, num_channels)
+        Returns a tensor of shape
+        (seq_len*upsample, batch_size, num_channels)
+        """
+        upsample = self.bias.shape[0]
+        (seq_len, batch_size, num_channels) = src.shape
+        src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
+        src = src + self.bias.unsqueeze(1)
+        src = src.reshape(seq_len * upsample, batch_size, num_channels)
+        return src
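A small check of the repeat-plus-bias behaviour (hypothetical sizes, assuming the class as defined above):

    import torch

    up = SimpleUpsample(num_channels=4, upsample=3)
    x = torch.randn(2, 1, 4)
    y = up(x)                       # (6, 1, 4)
    # output frame i*3 + j is input frame i plus bias row j
    assert torch.allclose(y[0], x[0] + up.bias[0])
    assert torch.allclose(y[4], x[1] + up.bias[1])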
+class SimpleCombiner(torch.nn.Module):
+    """
+    A very simple way of combining two vectors of two different dims, via a
+    learned weighted combination in the shared part of the dim.
+    """
+    def __init__(self,
+                 dim1: int,
+                 dim2: int):
+        super(SimpleCombiner, self).__init__()
+        assert dim2 > dim1
+        self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01)
+        self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01)
+
+    def forward(self,
+                src1: Tensor,
+                src2: Tensor) -> Tensor:
+        """
+        src1: (*, dim1)
+        src2: (*, dim2)
+
+        Returns: a tensor of shape (*, dim2)
+        """
+        assert src1.shape[:-1] == src2.shape[:-1]
+        dim1 = src1.shape[-1]
+        dim2 = src2.shape[-1]
+
+        weight1 = (src1 * self.to_weight1).sum(dim=-1, keepdim=True)
+        weight2 = (src2 * self.to_weight2).sum(dim=-1, keepdim=True)
+        weight = (weight1 + weight2).sigmoid()
+
+        src2_part1 = src2[...,:dim1]
+        part1 = src1 * weight + src2_part1 * (1.0 - weight)
+        part2 = src2[...,dim1:]
+        return torch.cat((part1, part2), dim=-1)
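In the Conformer changes at the top of this diff, out_combiner replaces the torch.cat + out_proj pair. A sketch with hypothetical dims matching that use (dim1=d_model[0], dim2=d_model[1], e.g. 256 and 384):

    import torch

    c = SimpleCombiner(dim1=256, dim2=384)
    x1 = torch.randn(50, 8, 256)
    x2 = torch.randn(50, 8, 384)
    out = c(x1, x2)
    assert out.shape == (50, 8, 384)
    # the first 256 channels are a sigmoid-gated mix of x1 and x2[..., :256];
    # the remaining 128 channels are passed through from x2 unchanged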

 class RelPositionalEncoding(torch.nn.Module):
     """Relative positional encoding module.