diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode-giga.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode-giga.py index a715a2a5c..9ae17fd11 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode-giga.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode-giga.py @@ -18,36 +18,36 @@ """ Usage: (1) greedy search -./pruned_transducer_stateless3/decode-giga.py \ +./pruned_transducer_stateless5/decode-giga.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 100 \ --decoding-method greedy_search (2) beam search -./pruned_transducer_stateless3/decode-giga.py \ +./pruned_transducer_stateless5/decode-giga.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 100 \ --decoding-method beam_search \ --beam-size 4 (3) modified beam search -./pruned_transducer_stateless3/decode-giga.py \ +./pruned_transducer_stateless5/decode-giga.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 100 \ --decoding-method modified_beam_search \ --beam-size 4 (4) fast beam search -./pruned_transducer_stateless3/decode-giga.py \ +./pruned_transducer_stateless5/decode-giga.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 1500 \ --decoding-method fast_beam_search \ --beam 4 \ @@ -128,7 +128,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless3/exp", + default="pruned_transducer_stateless5/exp", help="The experiment dir", ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 9a6b5a117..865709833 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -18,36 +18,36 @@ """ Usage: (1) greedy search -./pruned_transducer_stateless3/decode.py \ +./pruned_transducer_stateless5/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 100 \ --decoding-method greedy_search (2) beam search -./pruned_transducer_stateless3/decode.py \ +./pruned_transducer_stateless5/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 100 \ --decoding-method beam_search \ --beam-size 4 (3) modified beam search -./pruned_transducer_stateless3/decode.py \ +./pruned_transducer_stateless5/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 100 \ --decoding-method modified_beam_search \ --beam-size 4 (4) fast beam search -./pruned_transducer_stateless3/decode.py \ +./pruned_transducer_stateless5/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 1500 \ --decoding-method fast_beam_search \ --beam 4 \ @@ -127,7 +127,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless3/exp", + default="pruned_transducer_stateless5/exp", help="The experiment dir", ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless5/test_model.py new file mode 100755 index 000000000..1162eb379 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/test_model.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./pruned_transducer_stateless5/test_model.py +""" + +from train import get_params, get_transducer_model + + +def test_model(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + model = get_transducer_model(params) + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + +def main(): + test_model() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 4966ea57f..dcedcfec6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -25,22 +25,22 @@ cd egs/librispeech/ASR/ ./prepare.sh ./prepare_giga_speech.sh -./pruned_transducer_stateless3/train.py \ +./pruned_transducer_stateless5/train.py \ --world-size 4 \ --num-epochs 30 \ --start-epoch 0 \ - --exp-dir pruned_transducer_stateless3/exp \ + --exp-dir pruned_transducer_stateless5/exp \ --full-libri 1 \ --max-duration 300 # For mix precision training: -./pruned_transducer_stateless3/train.py \ +./pruned_transducer_stateless5/train.py \ --world-size 4 \ --num-epochs 30 \ --start-epoch 0 \ --use_fp16 1 \ - --exp-dir pruned_transducer_stateless3/exp \ + --exp-dir pruned_transducer_stateless5/exp \ --full-libri 1 \ --max-duration 550 @@ -154,7 +154,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless3/exp", + default="pruned_transducer_stateless5/exp", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -346,10 +346,11 @@ def get_params() -> AttributeDict: # parameters for conformer "feature_dim": 80, "subsampling_factor": 4, - "encoder_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, + "encoder_dim": 256, + "nhead": 4, + "dim_feedforward": 1024, + "num_encoder_layers": 18, + "knowledge_D": 512, # parameters for decoder "decoder_dim": 512, # parameters for joiner @@ -372,6 +373,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: nhead=params.nhead, dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, + knowledge_D=params.knowledge_D, ) return encoder