Minor fixes.

This commit is contained in:
Fangjun Kuang 2022-04-23 16:53:01 +08:00
parent 51cc6486cd
commit 85ac3a8000
4 changed files with 100 additions and 9 deletions

View File

@ -0,0 +1,44 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless2/test_model.py
"""
from train import get_params, get_transducer_model
def test_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
def main():
test_model()
if __name__ == "__main__":
main()

View File

@ -296,7 +296,7 @@ class ConformerEncoder(nn.Module):
assert num_layers - 1 not in aux_layers assert num_layers - 1 not in aux_layers
self.aux_layers = set(aux_layers + [num_layers - 1]) self.aux_layers = set(aux_layers + [num_layers - 1])
num_channels = encoder_layer.norm_final.weight.numel() num_channels = encoder_layer.norm_final.num_channels
self.combiner = RandomCombine( self.combiner = RandomCombine(
num_inputs=len(self.aux_layers), num_inputs=len(self.aux_layers),
num_channels=num_channels, num_channels=num_channels,

View File

@ -0,0 +1,47 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless4/test_model.py
"""
from train import get_params, get_transducer_model
def test_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.num_encoder_layers = 24
params.dim_feedforward = 1536 # 384 * 4
params.encoder_dim = 384
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
def main():
test_model()
if __name__ == "__main__":
main()

View File

@ -21,22 +21,22 @@ Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3" export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless2/train.py \ ./pruned_transducer_stateless4/train.py \
--world-size 4 \ --world-size 4 \
--num-epochs 30 \ --num-epochs 30 \
--start-epoch 0 \ --start-epoch 0 \
--exp-dir pruned_transducer_stateless2/exp \ --exp-dir pruned_transducer_stateless4/exp \
--full-libri 1 \ --full-libri 1 \
--max-duration 300 --max-duration 300
# For mix precision training: # For mix precision training:
./pruned_transducer_stateless2/train.py \ ./pruned_transducer_stateless4/train.py \
--world-size 4 \ --world-size 4 \
--num-epochs 30 \ --num-epochs 30 \
--start-epoch 0 \ --start-epoch 0 \
--use_fp16 1 \ --use_fp16 1 \
--exp-dir pruned_transducer_stateless2/exp \ --exp-dir pruned_transducer_stateless4/exp \
--full-libri 1 \ --full-libri 1 \
--max-duration 550 --max-duration 550
@ -138,7 +138,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="pruned_transducer_stateless2/exp", default="pruned_transducer_stateless4/exp",
help="""The experiment dir. help="""The experiment dir.
It specifies the directory where all training related It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved files, e.g., checkpoints, log, etc, are saved
@ -322,10 +322,10 @@ def get_params() -> AttributeDict:
# parameters for conformer # parameters for conformer
"feature_dim": 80, "feature_dim": 80,
"subsampling_factor": 4, "subsampling_factor": 4,
"encoder_dim": 512, "encoder_dim": 384,
"nhead": 8, "nhead": 8,
"dim_feedforward": 2048, "dim_feedforward": 1536,
"num_encoder_layers": 12, "num_encoder_layers": 24,
# parameters for decoder # parameters for decoder
"decoder_dim": 512, "decoder_dim": 512,
# parameters for joiner # parameters for joiner