From 8993b022d22bac0114dfeffd3c9ae9f24192d1fb Mon Sep 17 00:00:00 2001 From: k2-fsa Date: Fri, 19 Sep 2025 09:32:33 +0800 Subject: [PATCH] Fix setting joiner dim --- egs/aishell/ASR/zipformer/train.py | 2 +- egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py | 2 +- egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py | 2 +- egs/mdcc/ASR/zipformer/train.py | 2 +- egs/tedlium3/ASR/zipformer/train.py | 2 +- egs/wenetspeech/ASR/zipformer/train.py | 2 +- egs/wenetspeech/KWS/zipformer/train.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/egs/aishell/ASR/zipformer/train.py b/egs/aishell/ASR/zipformer/train.py index 3104665b0..02f62b210 100755 --- a/egs/aishell/ASR/zipformer/train.py +++ b/egs/aishell/ASR/zipformer/train.py @@ -639,7 +639,7 @@ def get_model(params: AttributeDict) -> nn.Module: encoder=encoder, decoder=decoder, joiner=joiner, - encoder_dim=int(max(params.encoder_dim.split(","))), + encoder_dim=max(_to_int_tuple(params.encoder_dim)), decoder_dim=params.decoder_dim, vocab_size=params.vocab_size, use_transducer=params.use_transducer, diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py index 93f7e1248..4fa1f44a6 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py @@ -651,7 +651,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: encoder=encoder, decoder=decoder, joiner=joiner, - encoder_dim=int(max(params.encoder_dim.split(","))), + encoder_dim=max(_to_int_tuple(params.encoder_dim)), decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py index 2a2c206aa..6b08e0b89 100755 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py @@ -881,7 +881,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: text_encoder=text_encoder, decoder=decoder, joiner=joiner, - encoder_dim=int(max(params.encoder_dim.split(","))), + encoder_dim=max(_to_int_tuple(params.encoder_dim)), decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, diff --git a/egs/mdcc/ASR/zipformer/train.py b/egs/mdcc/ASR/zipformer/train.py index 730db7718..d18cc5418 100755 --- a/egs/mdcc/ASR/zipformer/train.py +++ b/egs/mdcc/ASR/zipformer/train.py @@ -586,7 +586,7 @@ def get_model(params: AttributeDict) -> nn.Module: encoder=encoder, decoder=decoder, joiner=joiner, - encoder_dim=int(max(params.encoder_dim.split(","))), + encoder_dim=max(_to_int_tuple(params.encoder_dim)), decoder_dim=params.decoder_dim, vocab_size=params.vocab_size, ) diff --git a/egs/tedlium3/ASR/zipformer/train.py b/egs/tedlium3/ASR/zipformer/train.py index 14a44efb3..7f93be940 100755 --- a/egs/tedlium3/ASR/zipformer/train.py +++ b/egs/tedlium3/ASR/zipformer/train.py @@ -598,7 +598,7 @@ def get_model(params: AttributeDict) -> nn.Module: encoder=encoder, decoder=decoder, joiner=joiner, - encoder_dim=int(max(params.encoder_dim.split(","))), + encoder_dim=max(_to_int_tuple(params.encoder_dim)), decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, diff --git a/egs/wenetspeech/ASR/zipformer/train.py b/egs/wenetspeech/ASR/zipformer/train.py index 25b16f632..9cbe96b5f 100755 --- a/egs/wenetspeech/ASR/zipformer/train.py +++ b/egs/wenetspeech/ASR/zipformer/train.py @@ -590,7 +590,7 @@ def get_model(params: AttributeDict) -> nn.Module: encoder=encoder, decoder=decoder, joiner=joiner, - encoder_dim=int(max(params.encoder_dim.split(","))), + encoder_dim=max(_to_int_tuple(params.encoder_dim)), decoder_dim=params.decoder_dim, vocab_size=params.vocab_size, ) diff --git a/egs/wenetspeech/KWS/zipformer/train.py b/egs/wenetspeech/KWS/zipformer/train.py index 5d9d8de36..65d15b571 100755 --- a/egs/wenetspeech/KWS/zipformer/train.py +++ b/egs/wenetspeech/KWS/zipformer/train.py @@ -647,7 +647,7 @@ def get_model(params: AttributeDict) -> nn.Module: encoder=encoder, decoder=decoder, joiner=joiner, - encoder_dim=int(max(params.encoder_dim.split(","))), + encoder_dim=max(_to_int_tuple(params.encoder_dim)), decoder_dim=params.decoder_dim, vocab_size=params.vocab_size, )