add RNN and Conv2dSubsampling classes in lstm.py

yaozengwei 2022-07-17 12:59:27 +08:00
parent 7c9fcfa5c9
commit 2d53f2ef8b


@@ -15,9 +15,7 @@
# limitations under the License.
import copy
import math
import warnings
-from typing import List, Optional, Tuple
+from typing import Tuple
import torch
from encoder_interface import EncoderInterface
@@ -25,12 +23,157 @@ from scaling import (
ActivationBalancer,
BasicNorm,
DoubleSwish,
-    ScaledConv1d,
+    ScaledConv2d,
ScaledLinear,
ScaledLSTM,
)
-from torch import Tensor, nn
+from torch import nn
class RNN(EncoderInterface):
"""
Args:
num_features (int):
Number of input features.
subsampling_factor (int):
Subsampling factor of encoder (convolution layers before lstm layers).
d_model (int):
Hidden dimension for lstm layers, also output dimension (default=512).
dim_feedforward (int):
Feedforward dimension (default=2048).
num_encoder_layers (int):
Number of encoder layers (default=12).
dropout (float):
Dropout rate (default=0.1).
layer_dropout (float):
Dropout value for model-level warmup (default=0.075).
"""
def __init__(
self,
num_features: int,
subsampling_factor: int,
d_model: int = 512,
dim_feedforward: int = 2048,
num_encoder_layers: int = 12,
dropout: float = 0.1,
layer_dropout: float = 0.075,
) -> None:
super(RNN, self).__init__()
self.num_features = num_features
self.subsampling_factor = subsampling_factor
if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.")
# self.encoder_embed converts the input of shape (N, T, num_features)
# to the shape (N, T//subsampling_factor, d_model).
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_features -> d_model
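        # (A worked example, assuming the defaults: with num_features=80 and
        # d_model=512, an input of shape (N, 100, 80) becomes (N, 24, 512),
        # since ((100 - 1) // 2 - 1) // 2 = 24.)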
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
        self.num_encoder_layers = num_encoder_layers
self.d_model = d_model
encoder_layer = RNNEncoderLayer(
d_model, dim_feedforward, dropout, layer_dropout
)
self.encoder = RNNEncoder(encoder_layer, num_encoder_layers)
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
The input tensor. Its shape is (N, T, C), where N is the batch size,
T is the sequence length, C is the feature dimension.
x_lens:
A tensor of shape (N,), containing the number of frames in `x`
before padding.
warmup:
A floating point value that gradually increases from 0 throughout
training; when it is >= 1.0 we are "fully warmed up". It is used
to turn modules on sequentially.
Returns:
A tuple of 2 tensors:
          - embeddings: its shape is (N, T', d_model), where T' is the
            output sequence length.
          - lengths: a tensor of shape (N,) containing the number of
            frames in `embeddings` before padding.
"""
x = self.encoder_embed(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
        # lengths = ((x_lens - 1) // 2 - 1) // 2  # issues a warning
#
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
lengths = (((x_lens - 1) >> 1) - 1) >> 1
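        # The right shifts match the floor divisions above for non-negative
        # lengths, e.g. x_lens = 100: (100 - 1) >> 1 = 49 and
        # (49 - 1) >> 1 = 24, the same as ((100 - 1) // 2 - 1) // 2.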
assert x.size(0) == lengths.max().item()
x = self.encoder(x, warmup)
x = x.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
return x, lengths
@torch.jit.export
def get_init_state(self, device: torch.device) -> torch.Tensor:
"""Get model initial state."""
init_states = torch.zeros(
(2, self.num_encoder_layers, self.d_model), device=device
)
return init_states
@torch.jit.export
def infer(
self, x: torch.Tensor, x_lens: torch.Tensor, states: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
The input tensor. Its shape is (N, T, C), where N is the batch size,
T is the sequence length, C is the feature dimension.
x_lens:
A tensor of shape (N,), containing the number of frames in `x`
before padding.
states:
Its shape is (2, num_encoder_layers, N, E).
states[0] and states[1] are cached hidden states and cell states for
all layers, respectively.
Returns:
A tuple of 3 tensors:
          - embeddings: its shape is (N, T', d_model), where T' is the
            output sequence length.
          - lengths: a tensor of shape (N,) containing the number of
            frames in `embeddings` before padding.
          - updated states, whose shape is (2, num_encoder_layers, N, E).
"""
assert not self.training
assert states.shape == (
2,
self.num_encoder_layers,
x.size(0),
self.d_model,
), states.shape
        # lengths = ((x_lens - 1) // 2 - 1) // 2  # issues a warning
#
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
lengths = (((x_lens - 1) >> 1) - 1) >> 1
# we will cut off 1 frame on each side of encoder_embed output
lengths -= 2
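        # E.g., x_lens = 100: the subsampled length is 24; trimming one frame
        # on each side (embed[:, 1:-1, :] below) leaves 24 - 2 = 22 frames.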
embed = self.encoder_embed(x)
embed = embed[:, 1:-1, :]
embed = embed.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
x, states = self.encoder.infer(embed, states)
x = x.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
return x, lengths, states
class RNNEncoderLayer(nn.Module):
@@ -77,7 +220,7 @@ class RNNEncoderLayer(nn.Module):
)
self.dropout = nn.Dropout(dropout)
-    def forward(self, src: Tensor, warmup: float = 1.0) -> Tensor:
+    def forward(self, src: torch.Tensor, warmup: float = 1.0) -> torch.Tensor:
"""
Pass the input through the encoder layer.
@@ -120,8 +263,8 @@ class RNNEncoderLayer(nn.Module):
@torch.jit.export
def infer(
-        self, src: Tensor, states: Tuple[Tensor]
-    ) -> Tuple[Tensor, Tuple[Tensor]]:
+        self, src: torch.Tensor, states: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Pass the input through the encoder layer.
@@ -139,9 +282,9 @@
assert states.shape == (2, 1, src.size(1), src.size(2))
# lstm module
-        # The required shapes of h_0 and c_0 are both (1, N, E)
+        # The required shapes of h_0 and c_0 are both (1, N, E).
src_lstm, new_states = self.lstm(src, states.unbind(dim=0))
-        new_states = torch.stack(states, dim=0)
+        new_states = torch.stack(new_states, dim=0)
src = src + self.dropout(src_lstm)
# feed forward module
@@ -170,7 +313,7 @@ class RNNEncoder(nn.Module):
)
self.num_layers = num_layers
-    def forward(self, src: Tensor, warmup: float = 1.0) -> Tensor:
+    def forward(self, src: torch.Tensor, warmup: float = 1.0) -> torch.Tensor:
"""
Pass the input through the encoder layer in turn.
@@ -192,8 +335,8 @@ class RNNEncoder(nn.Module):
@torch.jit.export
def infer(
-        self, src: Tensor, states: Tuple[Tensor]
-    ) -> Tuple[Tensor, Tuple[Tensor]]:
+        self, src: torch.Tensor, states: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Pass the input through the encoder layer.
@@ -220,3 +363,97 @@ class RNNEncoder(nn.Module):
new_states_list.append(new_states)
return output, torch.cat(new_states_list, dim=1)
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
Convert an input of shape (N, T, idim) to an output
with shape (N, T', odim), where
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
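    For example, T = 100 gives T' = 24, while T//4 = 25.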
It is based on
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
"""
def __init__(
self,
in_channels: int,
out_channels: int,
layer1_channels: int = 8,
layer2_channels: int = 32,
layer3_channels: int = 128,
) -> None:
"""
Args:
in_channels:
            Number of channels in. The input shape is (N, T, in_channels).
            Caution: It requires T >= 7, in_channels >= 7.
          out_channels:
            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels).
          layer1_channels:
            Number of channels in layer1.
          layer2_channels:
            Number of channels in layer2.
          layer3_channels:
            Number of channels in layer3.
"""
assert in_channels >= 7
super().__init__()
self.conv = nn.Sequential(
ScaledConv2d(
in_channels=1,
out_channels=layer1_channels,
kernel_size=3,
padding=1,
),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
ScaledConv2d(
in_channels=layer1_channels,
out_channels=layer2_channels,
kernel_size=3,
stride=2,
),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
ScaledConv2d(
in_channels=layer2_channels,
out_channels=layer3_channels,
kernel_size=3,
stride=2,
),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
)
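        # Each stride-2 conv above (kernel 3, no padding) maps a length L to
        # (L - 1) // 2 on both the time and frequency axes, giving the overall
        # T -> ((T - 1) // 2 - 1) // 2 subsampling.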
self.out = ScaledLinear(
layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
)
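        # E.g., in_channels = 80: the frequency axis shrinks to
        # ((80 - 1) // 2 - 1) // 2 = 19, so `out` maps
        # layer3_channels * 19 = 128 * 19 = 2432 features to out_channels.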
# set learn_eps=False because out_norm is preceded by `out`, and `out`
# itself has learned scale, so the extra degree of freedom is not
# needed.
self.out_norm = BasicNorm(out_channels, learn_eps=False)
# constrain median of output to be close to zero.
self.out_balancer = ActivationBalancer(
channel_dim=-1, min_positive=0.45, max_positive=0.55
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
Returns:
Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
"""
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
x = self.conv(x)
# Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
x = self.out_norm(x)
x = self.out_balancer(x)
return x
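Below is a minimal smoke test for the two new classes. It is a sketch, not part of the commit: it assumes lstm.py and its sibling modules (scaling.py, encoder_interface.py) are importable, e.g. when run from the recipe directory, and that the input is 80-dim fbank features.

import torch

from lstm import RNN

model = RNN(num_features=80, subsampling_factor=4)
model.eval()

N, T = 2, 100
x = torch.randn(N, T, 80)
x_lens = torch.full((N,), T, dtype=torch.int64)

with torch.no_grad():
    out, out_lens = model(x, x_lens)

# ((100 - 1) // 2 - 1) // 2 = 24 output frames, each of size d_model = 512
assert out.shape == (N, 24, 512)
assert out_lens.tolist() == [24, 24]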