minor fix for param. names (#1495)

This commit is contained in:
zr_jin 2024-02-20 14:38:51 +08:00 committed by GitHub
parent e59fa38e86
commit 027302c902
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -159,7 +159,7 @@ class LmScorer(torch.nn.Module):
""" """
if lm_type == "rnn": if lm_type == "rnn":
model = RnnLmModel( model = RnnLmModel(
vocab_size=params.vocab_size, vocab_size=params.lm_vocab_size,
embedding_dim=params.rnn_lm_embedding_dim, embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim, hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers, num_layers=params.rnn_lm_num_layers,
@ -183,7 +183,7 @@ class LmScorer(torch.nn.Module):
elif lm_type == "transformer": elif lm_type == "transformer":
model = TransformerLM( model = TransformerLM(
vocab_size=params.vocab_size, vocab_size=params.lm_vocab_size,
d_model=params.transformer_lm_encoder_dim, d_model=params.transformer_lm_encoder_dim,
embedding_dim=params.transformer_lm_embedding_dim, embedding_dim=params.transformer_lm_embedding_dim,
dim_feedforward=params.transformer_lm_dim_feedforward, dim_feedforward=params.transformer_lm_dim_feedforward,