From 85ac3a8000fd33d59064418a2f18f33603fdda5e Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 23 Apr 2022 16:53:01 +0800 Subject: [PATCH] Minor fixes. --- .../test_model.py | 44 +++++++++++++++++ .../pruned_transducer_stateless4/conformer.py | 2 +- .../test_model.py | 47 +++++++++++++++++++ .../ASR/pruned_transducer_stateless4/train.py | 16 +++---- 4 files changed, 100 insertions(+), 9 deletions(-) create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless2/test_model.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless4/test_model.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/test_model.py new file mode 100755 index 000000000..e5f71ef2d --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/test_model.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./pruned_transducer_stateless2/test_model.py +""" + +from train import get_params, get_transducer_model + + +def test_model(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + model = get_transducer_model(params) + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + +def main(): + test_model() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless4/conformer.py index 43c4f468b..9a4dc61c5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/conformer.py @@ -296,7 +296,7 @@ class ConformerEncoder(nn.Module): assert num_layers - 1 not in aux_layers self.aux_layers = set(aux_layers + [num_layers - 1]) - num_channels = encoder_layer.norm_final.weight.numel() + num_channels = encoder_layer.norm_final.num_channels self.combiner = RandomCombine( num_inputs=len(self.aux_layers), num_channels=num_channels, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless4/test_model.py new file mode 100755 index 000000000..43f84e5c7 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/test_model.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./pruned_transducer_stateless4/test_model.py +""" + +from train import get_params, get_transducer_model + + +def test_model(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = 24 + params.dim_feedforward = 1536 # 384 * 4 + params.encoder_dim = 384 + model = get_transducer_model(params) + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + +def main(): + test_model() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 80617847a..31617c3b0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -21,22 +21,22 @@ Usage: export CUDA_VISIBLE_DEVICES="0,1,2,3" -./pruned_transducer_stateless2/train.py \ +./pruned_transducer_stateless4/train.py \ --world-size 4 \ --num-epochs 30 \ --start-epoch 0 \ - --exp-dir pruned_transducer_stateless2/exp \ + --exp-dir pruned_transducer_stateless4/exp \ --full-libri 1 \ --max-duration 300 # For mix precision training: -./pruned_transducer_stateless2/train.py \ +./pruned_transducer_stateless4/train.py \ --world-size 4 \ --num-epochs 30 \ --start-epoch 0 \ --use_fp16 1 \ - --exp-dir pruned_transducer_stateless2/exp \ + --exp-dir pruned_transducer_stateless4/exp \ --full-libri 1 \ --max-duration 550 @@ -138,7 +138,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless2/exp", + default="pruned_transducer_stateless4/exp", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -322,10 +322,10 @@ def get_params() -> AttributeDict: # parameters for conformer "feature_dim": 80, "subsampling_factor": 4, - "encoder_dim": 512, + "encoder_dim": 384, "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, + "dim_feedforward": 1536, + "num_encoder_layers": 24, # parameters for decoder "decoder_dim": 512, # parameters for joiner
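
A note on the conformer.py hunk above: the channel count passed to RandomCombine is now read from the stored norm_final.num_channels attribute instead of being inferred from norm_final.weight.numel(). The sketch below is a rough, hedged illustration of why that lookup is more robust; ToyNorm is a made-up stand-in, not icefall's BasicNorm, and simply models a norm layer whose only parameter is a scalar, so its parameter sizes say nothing about the channel dimension.

    #!/usr/bin/env python3
    # Illustrative sketch only: "ToyNorm" is an assumed stand-in for a norm
    # layer that records its channel count as an attribute while owning just
    # one scalar parameter. It is not icefall's BasicNorm.
    import torch
    import torch.nn as nn


    class ToyNorm(nn.Module):
        def __init__(self, num_channels: int):
            super().__init__()
            self.num_channels = num_channels
            # A single learnable scalar (an epsilon-like value). Its numel()
            # is 1, so it cannot be used to recover num_channels.
            self.eps = nn.Parameter(torch.tensor(0.25))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Scale each frame by the inverse RMS over the channel dimension.
            scale = (x.pow(2).mean(dim=-1, keepdim=True) + self.eps.abs()).rsqrt()
            return x * scale


    def main():
        norm = ToyNorm(num_channels=384)
        # Reading the stored attribute gives the channel count directly.
        print("num_channels attribute:", norm.num_channels)  # 384
        # Counting parameter elements does not, because the layer need not
        # own a weight of shape (num_channels,) at all.
        print("parameter elements:", sum(p.numel() for p in norm.parameters()))  # 1


    if __name__ == "__main__":
        main()

Reading the attribute also keeps the combiner setup independent of how the final norm layer is parameterized, which is the design choice the one-line change above appears to make.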