add RNN and Conv2dSubsampling classes in lstm.py
This commit is contained in:
parent 7c9fcfa5c9
commit 2d53f2ef8b
@@ -15,9 +15,7 @@
 # limitations under the License.

 import copy
-import math
-import warnings
-from typing import List, Optional, Tuple
+from typing import Tuple

 import torch
 from encoder_interface import EncoderInterface
@@ -25,12 +23,157 @@ from scaling import (
     ActivationBalancer,
     BasicNorm,
     DoubleSwish,
-    ScaledConv1d,
+    ScaledConv2d,
     ScaledLinear,
     ScaledLSTM,
 )
-from torch import Tensor, nn
+from torch import nn


+class RNN(EncoderInterface):
+    """
+    Args:
+      num_features (int):
+        Number of input features.
+      subsampling_factor (int):
+        Subsampling factor of encoder (convolution layers before lstm layers).
+      d_model (int):
+        Hidden dimension for lstm layers, also output dimension (default=512).
+      dim_feedforward (int):
+        Feedforward dimension (default=2048).
+      num_encoder_layers (int):
+        Number of encoder layers (default=12).
+      dropout (float):
+        Dropout rate (default=0.1).
+      layer_dropout (float):
+        Dropout value for model-level warmup (default=0.075).
+    """
+
+    def __init__(
+        self,
+        num_features: int,
+        subsampling_factor: int,
+        d_model: int = 512,
+        dim_feedforward: int = 2048,
+        num_encoder_layers: int = 12,
+        dropout: float = 0.1,
+        layer_dropout: float = 0.075,
+    ) -> None:
+        super(RNN, self).__init__()
+
+        self.num_features = num_features
+        self.subsampling_factor = subsampling_factor
+        if subsampling_factor != 4:
+            raise NotImplementedError("Support only 'subsampling_factor=4'.")
+
+        # self.encoder_embed converts the input of shape (N, T, num_features)
+        # to the shape (N, T//subsampling_factor, d_model).
+        # That is, it does two things simultaneously:
+        #   (1) subsampling: T -> T//subsampling_factor
+        #   (2) embedding: num_features -> d_model
+        self.encoder_embed = Conv2dSubsampling(num_features, d_model)
+
+        self.num_encoder_layers = num_encoder_layers
+        self.d_model = d_model
+
+        encoder_layer = RNNEncoderLayer(
+            d_model, dim_feedforward, dropout, layer_dropout
+        )
+        self.encoder = RNNEncoder(encoder_layer, num_encoder_layers)
+
+    def forward(
+        self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+          x:
+            The input tensor. Its shape is (N, T, C), where N is the batch size,
+            T is the sequence length, and C is the feature dimension.
+          x_lens:
+            A tensor of shape (N,), containing the number of frames in `x`
+            before padding.
+          warmup:
+            A floating point value that gradually increases from 0 throughout
+            training; when it is >= 1.0 we are "fully warmed up". It is used
+            to turn modules on sequentially.
+
+        Returns:
+          A tuple of 2 tensors:
+            - embeddings: its shape is (N, T', d_model), where T' is the output
+              sequence length.
+            - lengths: a tensor of shape (N,) containing the number of frames
+              in `embeddings` before padding.
+        """
+        x = self.encoder_embed(x)
+        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+
+        # lengths = ((x_lens - 1) // 2 - 1) // 2  # issues a warning
+        #
+        # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
+        lengths = (((x_lens - 1) >> 1) - 1) >> 1
+        assert x.size(0) == lengths.max().item()
+
+        x = self.encoder(x, warmup)
+
+        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        return x, lengths
+
+    @torch.jit.export
+    def get_init_state(self, device: torch.device) -> torch.Tensor:
+        """Get model initial state."""
+        init_states = torch.zeros(
+            (2, self.num_encoder_layers, self.d_model), device=device
+        )
+        return init_states
+
+    @torch.jit.export
+    def infer(
+        self, x: torch.Tensor, x_lens: torch.Tensor, states: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Args:
+          x:
+            The input tensor. Its shape is (N, T, C), where N is the batch size,
+            T is the sequence length, and C is the feature dimension.
+          x_lens:
+            A tensor of shape (N,), containing the number of frames in `x`
+            before padding.
+          states:
+            Its shape is (2, num_encoder_layers, N, E).
+            states[0] and states[1] are cached hidden states and cell states
+            for all layers, respectively.
+
+        Returns:
+          A tuple of 3 tensors:
+            - embeddings: its shape is (N, T', d_model), where T' is the output
+              sequence length.
+            - lengths: a tensor of shape (N,) containing the number of frames
+              in `embeddings` before padding.
+            - updated states, with shape (2, num_encoder_layers, N, E).
+        """
+        assert not self.training
+        assert states.shape == (
+            2,
+            self.num_encoder_layers,
+            x.size(0),
+            self.d_model,
+        ), states.shape
+
+        # lengths = ((x_lens - 1) // 2 - 1) // 2  # issues a warning
+        #
+        # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
+        lengths = (((x_lens - 1) >> 1) - 1) >> 1
+        # we will cut off 1 frame on each side of encoder_embed output
+        lengths -= 2
+
+        embed = self.encoder_embed(x)
+        embed = embed[:, 1:-1, :]
+        embed = embed.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+
+        x, states = self.encoder.infer(embed, states)
+
+        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        return x, lengths, states
+
+
 class RNNEncoderLayer(nn.Module):
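Note: the bit-shift form `(((x_lens - 1) >> 1) - 1) >> 1` used in both
forward() and infer() above is equivalent to the commented-out floor
division, but it avoids the `__floordiv__` deprecation warning on recent
torch and works on torch < 1.8.0, where torch.div() has no rounding_mode.
A minimal sanity check (the input lengths here are hypothetical):

    import torch

    x_lens = torch.tensor([100, 83, 7])
    by_shift = (((x_lens - 1) >> 1) - 1) >> 1
    by_div = torch.div(x_lens - 1, 2, rounding_mode="floor")
    by_div = torch.div(by_div - 1, 2, rounding_mode="floor")
    assert torch.equal(by_shift, by_div)  # both give ((T - 1)//2 - 1)//2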
@@ -77,7 +220,7 @@ class RNNEncoderLayer(nn.Module):
         )
         self.dropout = nn.Dropout(dropout)

-    def forward(self, src: Tensor, warmup: float = 1.0) -> Tensor:
+    def forward(self, src: torch.Tensor, warmup: float = 1.0) -> torch.Tensor:
         """
         Pass the input through the encoder layer.

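The `warmup` argument and the `layer_dropout` default above suggest the
model-level warmup gating used by icefall's reworked Conformer layers. A
self-contained sketch of that pattern follows; whether this file uses the
exact same constants (0.1, the blending rule) is an assumption:

    import torch

    def warmup_gate(src_orig: torch.Tensor, src: torch.Tensor,
                    warmup: float, layer_dropout: float,
                    training: bool) -> torch.Tensor:
        # Early in training (warmup < 0.9), the layer output is blended
        # with its input; with probability `layer_dropout`, the layer is
        # mostly bypassed so the model stays stable while warming up.
        warmup_scale = min(0.1 + warmup, 1.0)
        if training:
            alpha = (
                warmup_scale
                if torch.rand(()).item() <= (1.0 - layer_dropout)
                else 0.1
            )
        else:
            alpha = 1.0
        return alpha * src + (1 - alpha) * src_orig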
@@ -120,8 +263,8 @@ class RNNEncoderLayer(nn.Module):

     @torch.jit.export
     def infer(
-        self, src: Tensor, states: Tuple[Tensor]
-    ) -> Tuple[Tensor, Tuple[Tensor]]:
+        self, src: torch.Tensor, states: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Pass the input through the encoder layer.

@@ -139,9 +282,9 @@ class RNNEncoderLayer(nn.Module):
         assert states.shape == (2, 1, src.size(1), src.size(2))

         # lstm module
-        # The required shapes of h_0 and c_0 are both (1, N, E)
+        # The required shapes of h_0 and c_0 are both (1, N, E).
         src_lstm, new_states = self.lstm(src, states.unbind(dim=0))
-        new_states = torch.stack(states, dim=0)
+        new_states = torch.stack(new_states, dim=0)
         src = src + self.dropout(src_lstm)

         # feed forward module
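The second change in this hunk fixes a real bug: stacking `states` (the
incoming cache) instead of `new_states` would silently return the old
states from every call, so the stream would never advance. A minimal
round-trip of the state layout, using torch.nn.LSTM as a stand-in for
ScaledLSTM (a stand-in assumption for illustration):

    import torch

    N, E, T = 4, 512, 10
    lstm = torch.nn.LSTM(E, E)                   # stand-in for ScaledLSTM
    states = torch.zeros(2, 1, N, E)             # stacked (h_0, c_0) cache
    h, c = states.unbind(dim=0)                  # two (1, N, E) tensors
    src = torch.randn(T, N, E)
    out, new_states = lstm(src, (h, c))          # new_states is (h_n, c_n)
    new_states = torch.stack(new_states, dim=0)  # back to (2, 1, N, E)
    assert new_states.shape == (2, 1, N, E)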
@@ -170,7 +313,7 @@ class RNNEncoder(nn.Module):
         )
         self.num_layers = num_layers

-    def forward(self, src: Tensor, warmup: float = 1.0) -> Tensor:
+    def forward(self, src: torch.Tensor, warmup: float = 1.0) -> torch.Tensor:
         """
         Pass the input through the encoder layers in turn.

@@ -192,8 +335,8 @@ class RNNEncoder(nn.Module):

     @torch.jit.export
     def infer(
-        self, src: Tensor, states: Tuple[Tensor]
-    ) -> Tuple[Tensor, Tuple[Tensor]]:
+        self, src: torch.Tensor, states: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Pass the input through the encoder layers in turn.

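Switching the cached states from `Tuple[Tensor]` to one stacked `Tensor`
gives the `@torch.jit.export` methods a fixed, shape-checkable layout that
TorchScript handles cleanly. A hedged sketch of driving the streaming entry
point; the chunk length and the batch-dimension handling for
get_init_state() are assumptions:

    import torch

    # Hypothetical streaming loop; `RNN` is the class added in this commit,
    # and the 23-frame chunks are invented for illustration.
    model = RNN(num_features=80, subsampling_factor=4)
    model.eval()

    N = 1
    states = model.get_init_state(torch.device("cpu"))
    # get_init_state() returns (2, num_encoder_layers, d_model); infer()
    # expects (2, num_encoder_layers, N, E), so add a batch dim (assumed).
    states = states.unsqueeze(2).expand(-1, -1, N, -1).contiguous()

    for chunk in torch.randn(3, N, 23, 80).unbind(0):
        x_lens = torch.full((N,), 23, dtype=torch.int64)
        out, out_lens, states = model.infer(chunk, x_lens, states)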
@@ -220,3 +363,97 @@ class RNNEncoder(nn.Module):
             new_states_list.append(new_states)

         return output, torch.cat(new_states_list, dim=1)
+
+
+class Conv2dSubsampling(nn.Module):
+    """Convolutional 2D subsampling (to 1/4 length).
+
+    Convert an input of shape (N, T, idim) to an output
+    with shape (N, T', odim), where
+    T' = ((T-1)//2 - 1)//2, which approximates T' == T//4.
+
+    It is based on
+    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        layer1_channels: int = 8,
+        layer2_channels: int = 32,
+        layer3_channels: int = 128,
+    ) -> None:
+        """
+        Args:
+          in_channels:
+            Number of input channels. The input shape is (N, T, in_channels).
+            Caution: It requires: T >= 7, in_channels >= 7.
+          out_channels:
+            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels).
+          layer1_channels:
+            Number of channels in layer1.
+          layer2_channels:
+            Number of channels in layer2.
+          layer3_channels:
+            Number of channels in layer3.
+        """
+        assert in_channels >= 7
+        super().__init__()
+
+        self.conv = nn.Sequential(
+            ScaledConv2d(
+                in_channels=1,
+                out_channels=layer1_channels,
+                kernel_size=3,
+                padding=1,
+            ),
+            ActivationBalancer(channel_dim=1),
+            DoubleSwish(),
+            ScaledConv2d(
+                in_channels=layer1_channels,
+                out_channels=layer2_channels,
+                kernel_size=3,
+                stride=2,
+            ),
+            ActivationBalancer(channel_dim=1),
+            DoubleSwish(),
+            ScaledConv2d(
+                in_channels=layer2_channels,
+                out_channels=layer3_channels,
+                kernel_size=3,
+                stride=2,
+            ),
+            ActivationBalancer(channel_dim=1),
+            DoubleSwish(),
+        )
+        self.out = ScaledLinear(
+            layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
+        )
+        # set learn_eps=False because out_norm is preceded by `out`, and `out`
+        # itself has a learned scale, so the extra degree of freedom is not
+        # needed.
+        self.out_norm = BasicNorm(out_channels, learn_eps=False)
+        # constrain the median of the output to be close to zero.
+        self.out_balancer = ActivationBalancer(
+            channel_dim=-1, min_positive=0.45, max_positive=0.55
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Subsample x.
+
+        Args:
+          x:
+            Its shape is (N, T, idim).
+
+        Returns:
+          A tensor of shape (N, ((T-1)//2 - 1)//2, odim).
+        """
+        # On entry, x is (N, T, idim)
+        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim), i.e., (N, C, H, W)
+        x = self.conv(x)
+        # Now x is (N, layer3_channels, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
+        b, c, t, f = x.size()
+        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+        # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
+        x = self.out_norm(x)
+        x = self.out_balancer(x)
+        return x
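A minimal shape check of the subsampling arithmetic above, substituting
plain nn.Conv2d for ScaledConv2d (a stand-in assumption; the scaled
variants keep the same output shapes):

    import torch
    import torch.nn as nn

    N, T, idim = 4, 100, 80
    conv = nn.Sequential(
        nn.Conv2d(1, 8, kernel_size=3, padding=1),    # T and idim unchanged
        nn.Conv2d(8, 32, kernel_size=3, stride=2),    # (x - 1) // 2 on both dims
        nn.Conv2d(32, 128, kernel_size=3, stride=2),  # halved again
    )
    y = conv(torch.randn(N, 1, T, idim))
    t_out = ((T - 1) // 2 - 1) // 2
    f_out = ((idim - 1) // 2 - 1) // 2
    assert y.shape == (N, 128, t_out, f_out)
    # So ScaledLinear's input size is layer3_channels * (((idim - 1)//2 - 1)//2).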