Modify subsampling.py to make T'=T//4 strictly
This commit is contained in: parent 022b0f3c55, commit b0bce20e21
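A quick check of the claim in the title (a sketch only, not part of the commit; it merely restates the kernel-3, stride-2 Conv2d output-length arithmetic that Conv2dSubsampling below relies on):

# Each subsampling stage pads one extra frame and then applies a kernel-3,
# stride-2 Conv2d, so its output length is ((T + 1) - 3) // 2 + 1 == T // 2.
# Doing this twice gives exactly T // 4 for every T >= 4.
for T in range(4, 1000):
    assert ((T + 1) - 3) // 2 + 1 == T // 2
    assert ((T // 2 + 1) - 3) // 2 + 1 == T // 4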
subsampling.py
@@ -0,0 +1,166 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn as nn


class Conv2dSubsampling(nn.Module):
    """Convolutional 2D subsampling (to 1/4 length).

    Convert an input of shape (N, T, idim) to an output
    with shape (N, T', odim), where T' == T // 4.

    It is based on
    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
    """

    def __init__(self, idim: int, odim: int) -> None:
        """
        Args:
          idim:
            Input dim. The input shape is (N, T, idim).
            Caution: It requires: T >= 4, idim >= 7
          odim:
            Output dim. The output shape is (N, T // 4, odim)
        """
        assert idim >= 7
        super().__init__()
        self.conv_1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1, out_channels=odim, kernel_size=3, stride=2
            ),
            nn.ReLU(),
        )
        self.conv_2 = nn.Sequential(
            nn.Conv2d(
                in_channels=odim, out_channels=odim, kernel_size=3, stride=2
            ),
            nn.ReLU(),
        )
        self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x.

        Args:
          x:
            Its shape is (N, T, idim).

        Returns:
          Return a tensor of shape (N, T // 4, odim)
        """
        # On entry, x is (N, T, idim)
        x = x.unsqueeze(1)
        # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
        x = nn.functional.pad(x, (0, 0, 0, 1), "constant", 0)
        # x is of shape (N, 1, T + 1, idim)
        x = self.conv_1(x)
        # Now x is of shape (N, odim, T // 2, (idim - 1) // 2)
        x = nn.functional.pad(x, (0, 0, 0, 1), "constant", 0)
        # x is of shape (N, odim, T // 2 + 1, (idim - 1) // 2)
        x = self.conv_2(x)
        # Now x is of shape (N, odim, T // 4, ((idim - 1) // 2 - 1) // 2)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        # Now x is of shape (N, T // 4, odim)
        return x


class VggSubsampling(nn.Module):
    """Trying to follow the setup described in the following paper:
    https://arxiv.org/pdf/1910.09799.pdf

    This paper is not 100% explicit, so I am guessing to some extent
    and comparing with other VGG implementations.

    Convert an input of shape (N, T, idim) to an output
    with shape (N, T', odim), where T' == T // 4.
    """

    def __init__(self, idim: int, odim: int) -> None:
        """Construct a VggSubsampling object.

        This uses 2 VGG blocks with 2 Conv2d layers each,
        subsampling its input by a factor of 4 in the time dimension.

        Args:
          idim:
            Input dim. The input shape is (N, T, idim).
            Caution: It requires: T >= 4, idim >= 4.
          odim:
            Output dim. The output shape is (N, T // 4, odim)
        """
        super().__init__()

        cur_channels = 1
        layers = []
        block_dims = [32, 64]

        # padding=1 is used for both convolutions, and padding=0 with
        # ceil_mode=False for the max-pooling, so that the number of frames
        # at the output is exactly T // 4.
        # (An earlier setup used padding=0 for the 2nd convolution and for
        # the max-pooling, with ceil_mode=True, for back-compatibility; it
        # produced (((T - 1) // 2) - 1) // 2 output frames instead.)
        for block_dim in block_dims:
            layers.append(
                torch.nn.Conv2d(
                    in_channels=cur_channels,
                    out_channels=block_dim,
                    kernel_size=3,
                    padding=1,
                    stride=1,
                )
            )
            layers.append(torch.nn.ReLU())
            layers.append(
                torch.nn.Conv2d(
                    in_channels=block_dim,
                    out_channels=block_dim,
                    kernel_size=3,
                    padding=1,
                    stride=1,
                )
            )
            layers.append(
                torch.nn.MaxPool2d(
                    kernel_size=2, stride=2, padding=0, ceil_mode=False
                )
            )
            cur_channels = block_dim

        self.layers = nn.Sequential(*layers)

        self.out = nn.Linear(block_dims[-1] * (idim // 4), odim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x.

        Args:
          x:
            Its shape is (N, T, idim).

        Returns:
          Return a tensor of shape (N, T // 4, odim)
        """
        x = x.unsqueeze(1)
        # Now x is of shape (N, 1, T, idim), i.e., (N, C, H, W)
        x = self.layers(x)
        # Now x is of shape (N, block_dims[-1], T // 4, idim // 4)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        # Now x is of shape (N, T // 4, odim)
        return x
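A minimal usage sketch of the two modules (illustrative only; the batch size and the values T=100, idim=80, odim=256 are arbitrary choices, not taken from the commit):

import torch
from subsampling import Conv2dSubsampling, VggSubsampling

x = torch.randn(8, 100, 80)              # (N, T, idim), e.g. 80-dim fbank frames
conv_sub = Conv2dSubsampling(idim=80, odim=256)
vgg_sub = VggSubsampling(idim=80, odim=256)

# Both modules now shrink the time axis to exactly T // 4 == 25 frames.
print(conv_sub(x).shape)                 # torch.Size([8, 25, 256])
print(vgg_sub(x).shape)                  # torch.Size([8, 25, 256])

# The frequency axis is folded into the final Linear projection:
# Conv2dSubsampling uses odim * (((idim - 1) // 2 - 1) // 2) = 256 * 19 input features,
# VggSubsampling uses block_dims[-1] * (idim // 4) = 64 * 20 input features.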
@@ -0,0 +1,25 @@
import torch
from subsampling import Conv2dSubsampling, VggSubsampling


def test_conv2d_subsampling():
    B, idim, odim = 1, 80, 512
    model = Conv2dSubsampling(idim, odim)
    for t in range(4, 50):
        x = torch.randn(B, t, idim)
        outputs = model(x)
        assert outputs.shape == (B, t // 4, odim)


def test_vgg_subsampling():
    B, idim, odim = 1, 80, 512
    model = VggSubsampling(idim, odim)
    for t in range(4, 50):
        x = torch.randn(B, t, idim)
        outputs = model(x)
        assert outputs.shape == (B, t // 4, odim)


if __name__ == "__main__":
    test_conv2d_subsampling()
    test_vgg_subsampling()
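For comparison, here is a tiny sketch (plain integer arithmetic, nothing from the commit itself) of how the output length of the earlier VggSubsampling setup differs from the strict T // 4 enforced now:

# Earlier setup: (((T - 1) // 2) - 1) // 2 frames; this commit: T // 4 frames.
for T in range(4, 16):
    old_len = (((T - 1) // 2) - 1) // 2
    new_len = T // 4
    print(f"T={T:2d}  old={old_len}  new={new_len}")
# e.g. T=8 gives old=1 but new=2, so downstream code can now rely on T // 4.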