Update encoder model parameters.

This commit is contained in:
Fangjun Kuang 2022-05-05 21:15:13 +08:00
parent eac839478b
commit e38c6aa7fa
4 changed files with 73 additions and 27 deletions

View File

@ -18,36 +18,36 @@
""" """
Usage: Usage:
(1) greedy search (1) greedy search
./pruned_transducer_stateless3/decode-giga.py \ ./pruned_transducer_stateless5/decode-giga.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 100 \ --max-duration 100 \
--decoding-method greedy_search --decoding-method greedy_search
(2) beam search (2) beam search
./pruned_transducer_stateless3/decode-giga.py \ ./pruned_transducer_stateless5/decode-giga.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 100 \ --max-duration 100 \
--decoding-method beam_search \ --decoding-method beam_search \
--beam-size 4 --beam-size 4
(3) modified beam search (3) modified beam search
./pruned_transducer_stateless3/decode-giga.py \ ./pruned_transducer_stateless5/decode-giga.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 100 \ --max-duration 100 \
--decoding-method modified_beam_search \ --decoding-method modified_beam_search \
--beam-size 4 --beam-size 4
(4) fast beam search (4) fast beam search
./pruned_transducer_stateless3/decode-giga.py \ ./pruned_transducer_stateless5/decode-giga.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 1500 \ --max-duration 1500 \
--decoding-method fast_beam_search \ --decoding-method fast_beam_search \
--beam 4 \ --beam 4 \
@ -128,7 +128,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="pruned_transducer_stateless3/exp", default="pruned_transducer_stateless5/exp",
help="The experiment dir", help="The experiment dir",
) )

View File

@ -18,36 +18,36 @@
""" """
Usage: Usage:
(1) greedy search (1) greedy search
./pruned_transducer_stateless3/decode.py \ ./pruned_transducer_stateless5/decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 100 \ --max-duration 100 \
--decoding-method greedy_search --decoding-method greedy_search
(2) beam search (2) beam search
./pruned_transducer_stateless3/decode.py \ ./pruned_transducer_stateless5/decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 100 \ --max-duration 100 \
--decoding-method beam_search \ --decoding-method beam_search \
--beam-size 4 --beam-size 4
(3) modified beam search (3) modified beam search
./pruned_transducer_stateless3/decode.py \ ./pruned_transducer_stateless5/decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 100 \ --max-duration 100 \
--decoding-method modified_beam_search \ --decoding-method modified_beam_search \
--beam-size 4 --beam-size 4
(4) fast beam search (4) fast beam search
./pruned_transducer_stateless3/decode.py \ ./pruned_transducer_stateless5/decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 1500 \ --max-duration 1500 \
--decoding-method fast_beam_search \ --decoding-method fast_beam_search \
--beam 4 \ --beam 4 \
@ -127,7 +127,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="pruned_transducer_stateless3/exp", default="pruned_transducer_stateless5/exp",
help="The experiment dir", help="The experiment dir",
) )

View File

@ -0,0 +1,44 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless5/test_model.py
"""
from train import get_params, get_transducer_model
def test_model():
    """Build the transducer model and print its total parameter count.

    Starts from the parameters returned by ``get_params()``, overrides
    ``vocab_size``, ``blank_id`` and ``context_size`` with fixed values,
    constructs the model via ``get_transducer_model`` and prints the sum
    of ``numel()`` over everything yielded by ``model.parameters()``.
    """
    params = get_params()
    params.vocab_size = 500
    params.blank_id = 0
    params.context_size = 2

    model = get_transducer_model(params)

    # A generator expression avoids materializing a throwaway list
    # just to add up its elements.
    num_param = sum(p.numel() for p in model.parameters())
    print(f"Number of model parameters: {num_param}")
def main():
    """Script entry point: run the model construction check."""
    test_model()
# Run the check only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()

View File

@ -25,22 +25,22 @@ cd egs/librispeech/ASR/
./prepare.sh ./prepare.sh
./prepare_giga_speech.sh ./prepare_giga_speech.sh
./pruned_transducer_stateless3/train.py \ ./pruned_transducer_stateless5/train.py \
--world-size 4 \ --world-size 4 \
--num-epochs 30 \ --num-epochs 30 \
--start-epoch 0 \ --start-epoch 0 \
--exp-dir pruned_transducer_stateless3/exp \ --exp-dir pruned_transducer_stateless5/exp \
--full-libri 1 \ --full-libri 1 \
--max-duration 300 --max-duration 300
# For mix precision training: # For mix precision training:
./pruned_transducer_stateless3/train.py \ ./pruned_transducer_stateless5/train.py \
--world-size 4 \ --world-size 4 \
--num-epochs 30 \ --num-epochs 30 \
--start-epoch 0 \ --start-epoch 0 \
--use_fp16 1 \ --use_fp16 1 \
--exp-dir pruned_transducer_stateless3/exp \ --exp-dir pruned_transducer_stateless5/exp \
--full-libri 1 \ --full-libri 1 \
--max-duration 550 --max-duration 550
@ -154,7 +154,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="pruned_transducer_stateless3/exp", default="pruned_transducer_stateless5/exp",
help="""The experiment dir. help="""The experiment dir.
It specifies the directory where all training related It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved files, e.g., checkpoints, log, etc, are saved
@ -346,10 +346,11 @@ def get_params() -> AttributeDict:
# parameters for conformer # parameters for conformer
"feature_dim": 80, "feature_dim": 80,
"subsampling_factor": 4, "subsampling_factor": 4,
"encoder_dim": 512, "encoder_dim": 256,
"nhead": 8, "nhead": 4,
"dim_feedforward": 2048, "dim_feedforward": 1024,
"num_encoder_layers": 12, "num_encoder_layers": 18,
"knowledge_D": 512,
# parameters for decoder # parameters for decoder
"decoder_dim": 512, "decoder_dim": 512,
# parameters for joiner # parameters for joiner
@ -372,6 +373,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
nhead=params.nhead, nhead=params.nhead,
dim_feedforward=params.dim_feedforward, dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers, num_encoder_layers=params.num_encoder_layers,
knowledge_D=params.knowledge_D,
) )
return encoder return encoder