mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
hardcode bidirectional=False
This commit is contained in:
parent
2d53f2ef8b
commit
074bd7da71
@ -36,7 +36,7 @@ class RNN(EncoderInterface):
|
|||||||
num_features (int):
|
num_features (int):
|
||||||
Number of input features.
|
Number of input features.
|
||||||
subsampling_factor (int):
|
subsampling_factor (int):
|
||||||
Subsampling factor of encoder (convolution layers before lstm layers).
|
Subsampling factor of encoder (convolution layers before lstm layers) (default=4). # noqa
|
||||||
d_model (int):
|
d_model (int):
|
||||||
Hidden dimension for lstm layers, also output dimension (default=512).
|
Hidden dimension for lstm layers, also output dimension (default=512).
|
||||||
dim_feedforward (int):
|
dim_feedforward (int):
|
||||||
@ -52,7 +52,7 @@ class RNN(EncoderInterface):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_features: int,
|
num_features: int,
|
||||||
subsampling_factor: int,
|
subsampling_factor: int = 4,
|
||||||
d_model: int = 512,
|
d_model: int = 512,
|
||||||
dim_feedforward: int = 2048,
|
dim_feedforward: int = 2048,
|
||||||
num_encoder_layers: int = 12,
|
num_encoder_layers: int = 12,
|
||||||
@ -457,3 +457,16 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
x = self.out_norm(x)
|
x = self.out_norm(x)
|
||||||
x = self.out_balancer(x)
|
x = self.out_balancer(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
feature_dim = 50
|
||||||
|
m = RNN(num_features=feature_dim, d_model=128)
|
||||||
|
batch_size = 5
|
||||||
|
seq_len = 20
|
||||||
|
# Just make sure the forward pass runs.
|
||||||
|
f = m(
|
||||||
|
torch.randn(batch_size, seq_len, feature_dim),
|
||||||
|
torch.full((batch_size,), seq_len, dtype=torch.int64),
|
||||||
|
warmup=0.5,
|
||||||
|
)
|
||||||
|
@ -388,9 +388,9 @@ class ScaledLSTM(nn.LSTM):
|
|||||||
initial_speed: float = 1.0,
|
initial_speed: float = 1.0,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
# Hardcode num_layers=1 and proj_size=0 here
|
# Hardcode num_layers=1, bidirectional=False, proj_size=0 here
|
||||||
super(ScaledLSTM, self).__init__(
|
super(ScaledLSTM, self).__init__(
|
||||||
*args, num_layers=1, proj_size=0, **kwargs
|
*args, num_layers=1, bidirectional=False, proj_size=0, **kwargs
|
||||||
)
|
)
|
||||||
initial_scale = torch.tensor(initial_scale).log()
|
initial_scale = torch.tensor(initial_scale).log()
|
||||||
self._scales_names = []
|
self._scales_names = []
|
||||||
|
Loading…
x
Reference in New Issue
Block a user