From 074bd7da717b0aa3a2ffd217d315401e23d9a8d1 Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Sun, 17 Jul 2022 15:31:25 +0800
Subject: [PATCH] hardcode bidirectional=False

---
 .../ASR/lstm_transducer_stateless/lstm.py       | 17 +++++++++++++++--
 .../ASR/pruned_transducer_stateless2/scaling.py |  4 ++--
 2 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py
index e180d9ec6..52424c0bb 100644
--- a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py
@@ -36,7 +36,7 @@ class RNN(EncoderInterface):
       num_features (int):
         Number of input features.
       subsampling_factor (int):
-        Subsampling factor of encoder (convolution layers before lstm layers).
+        Subsampling factor of encoder (convolution layers before lstm layers) (default=4).  # noqa
       d_model (int):
         Hidden dimension for lstm layers, also output dimension (default=512).
       dim_feedforward (int):
@@ -52,7 +52,7 @@
     def __init__(
         self,
         num_features: int,
-        subsampling_factor: int,
+        subsampling_factor: int = 4,
         d_model: int = 512,
         dim_feedforward: int = 2048,
         num_encoder_layers: int = 12,
@@ -457,3 +457,16 @@ class Conv2dSubsampling(nn.Module):
         x = self.out_norm(x)
         x = self.out_balancer(x)
         return x
+
+
+if __name__ == "__main__":
+    feature_dim = 50
+    m = RNN(num_features=feature_dim, d_model=128)
+    batch_size = 5
+    seq_len = 20
+    # Just make sure the forward pass runs.
+    f = m(
+        torch.randn(batch_size, seq_len, feature_dim),
+        torch.full((batch_size,), seq_len, dtype=torch.int64),
+        warmup=0.5,
+    )
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index 54f4a53c5..d0c16cd1e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -388,9 +388,9 @@ class ScaledLSTM(nn.LSTM):
         initial_speed: float = 1.0,
         **kwargs
     ):
-        # Hardcode num_layers=1 and proj_size=0 here
+        # Hardcode num_layers=1, bidirectional=False, proj_size=0 here
         super(ScaledLSTM, self).__init__(
-            *args, num_layers=1, proj_size=0, **kwargs
+            *args, num_layers=1, bidirectional=False, proj_size=0, **kwargs
         )
         initial_scale = torch.tensor(initial_scale).log()
         self._scales_names = []