hardcode bidirectional=False

yaozengwei 2022-07-17 15:31:25 +08:00
parent 2d53f2ef8b
commit 074bd7da71
2 changed files with 17 additions and 4 deletions


@@ -36,7 +36,7 @@ class RNN(EncoderInterface):
       num_features (int):
         Number of input features.
       subsampling_factor (int):
-        Subsampling factor of encoder (convolution layers before lstm layers).
+        Subsampling factor of encoder (convolution layers before lstm layers) (default=4).  # noqa
       d_model (int):
         Hidden dimension for lstm layers, also output dimension (default=512).
       dim_feedforward (int):
@@ -52,7 +52,7 @@ class RNN(EncoderInterface):
     def __init__(
         self,
         num_features: int,
-        subsampling_factor: int,
+        subsampling_factor: int = 4,
         d_model: int = 512,
         dim_feedforward: int = 2048,
         num_encoder_layers: int = 12,
@@ -457,3 +457,16 @@ class Conv2dSubsampling(nn.Module):
         x = self.out_norm(x)
         x = self.out_balancer(x)
         return x
+
+
+if __name__ == "__main__":
+    feature_dim = 50
+    m = RNN(num_features=feature_dim, d_model=128)
+    batch_size = 5
+    seq_len = 20
+    # Just make sure the forward pass runs.
+    f = m(
+        torch.randn(batch_size, seq_len, feature_dim),
+        torch.full((batch_size,), seq_len, dtype=torch.int64),
+        warmup=0.5,
+    )


@@ -388,9 +388,9 @@ class ScaledLSTM(nn.LSTM):
         initial_speed: float = 1.0,
         **kwargs
     ):
-        # Hardcode num_layers=1 and proj_size=0 here
+        # Hardcode num_layers=1, bidirectional=False, proj_size=0 here
        super(ScaledLSTM, self).__init__(
-            *args, num_layers=1, proj_size=0, **kwargs
+            *args, num_layers=1, bidirectional=False, proj_size=0, **kwargs
         )
         initial_scale = torch.tensor(initial_scale).log()
         self._scales_names = []
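A likely reason for hardcoding bidirectional=False — an inference from the _scales_names bookkeeping visible in the hunk, not something stated in the commit — is that ScaledLSTM attaches a scale to each flat weight of the underlying nn.LSTM, and a bidirectional LSTM doubles that weight list with "_reverse" entries while also doubling the output feature dimension. A minimal PyTorch check (it reads the private _flat_weights_names attribute, for illustration only):

import torch.nn as nn

# Plain nn.LSTM with the kwargs ScaledLSTM now hardcodes.
uni = nn.LSTM(input_size=8, hidden_size=16, num_layers=1,
              bidirectional=False, proj_size=0)
print(uni._flat_weights_names)
# ['weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0']

# With bidirectional=True the flat-weight list doubles ("_reverse"
# entries) and the output dimension becomes 2 * hidden_size, so
# per-weight bookkeeping written against the list above would
# silently miss half of the parameters.
bi = nn.LSTM(input_size=8, hidden_size=16, num_layers=1, bidirectional=True)
print(bi._flat_weights_names)
# ['weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0',
#  'weight_ih_l0_reverse', 'weight_hh_l0_reverse',
#  'bias_ih_l0_reverse', 'bias_hh_l0_reverse']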