Fix setting joiner dim (#2027)

Fixes incorrect computation of encoder_dim when encoder_dim is a comma-separated list of integers by ensuring numeric (not lexicographic) max is used.

Fixes #2018

- Replace int(max(params.encoder_dim.split(","))) (lexicographic max on strings) with max(_to_int_tuple(params.encoder_dim)) (numeric max).
- Apply the fix consistently across all affected training scripts.
This commit is contained in:
Fangjun Kuang 2025-09-19 09:42:41 +08:00 committed by GitHub
parent 0c7ce5256f
commit 63563d16d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 7 additions and 7 deletions

View File

@ -639,7 +639,7 @@ def get_model(params: AttributeDict) -> nn.Module:
encoder=encoder,
decoder=decoder,
joiner=joiner,
encoder_dim=int(max(params.encoder_dim.split(","))),
encoder_dim=max(_to_int_tuple(params.encoder_dim)),
decoder_dim=params.decoder_dim,
vocab_size=params.vocab_size,
use_transducer=params.use_transducer,

View File

@ -651,7 +651,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder=encoder,
decoder=decoder,
joiner=joiner,
encoder_dim=int(max(params.encoder_dim.split(","))),
encoder_dim=max(_to_int_tuple(params.encoder_dim)),
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,

View File

@ -881,7 +881,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
text_encoder=text_encoder,
decoder=decoder,
joiner=joiner,
encoder_dim=int(max(params.encoder_dim.split(","))),
encoder_dim=max(_to_int_tuple(params.encoder_dim)),
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,

View File

@ -586,7 +586,7 @@ def get_model(params: AttributeDict) -> nn.Module:
encoder=encoder,
decoder=decoder,
joiner=joiner,
encoder_dim=int(max(params.encoder_dim.split(","))),
encoder_dim=max(_to_int_tuple(params.encoder_dim)),
decoder_dim=params.decoder_dim,
vocab_size=params.vocab_size,
)

View File

@ -598,7 +598,7 @@ def get_model(params: AttributeDict) -> nn.Module:
encoder=encoder,
decoder=decoder,
joiner=joiner,
encoder_dim=int(max(params.encoder_dim.split(","))),
encoder_dim=max(_to_int_tuple(params.encoder_dim)),
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,

View File

@ -590,7 +590,7 @@ def get_model(params: AttributeDict) -> nn.Module:
encoder=encoder,
decoder=decoder,
joiner=joiner,
encoder_dim=int(max(params.encoder_dim.split(","))),
encoder_dim=max(_to_int_tuple(params.encoder_dim)),
decoder_dim=params.decoder_dim,
vocab_size=params.vocab_size,
)

View File

@ -647,7 +647,7 @@ def get_model(params: AttributeDict) -> nn.Module:
encoder=encoder,
decoder=decoder,
joiner=joiner,
encoder_dim=int(max(params.encoder_dim.split(","))),
encoder_dim=max(_to_int_tuple(params.encoder_dim)),
decoder_dim=params.decoder_dim,
vocab_size=params.vocab_size,
)