minor fix for param. names (#1495)

This commit is contained in:
zr_jin 2024-02-20 14:38:51 +08:00 committed by GitHub
parent e59fa38e86
commit 027302c902
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -159,7 +159,7 @@ class LmScorer(torch.nn.Module):
"""
if lm_type == "rnn":
model = RnnLmModel(
vocab_size=params.vocab_size,
vocab_size=params.lm_vocab_size,
embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers,
@ -183,7 +183,7 @@ class LmScorer(torch.nn.Module):
elif lm_type == "transformer":
model = TransformerLM(
vocab_size=params.vocab_size,
vocab_size=params.lm_vocab_size,
d_model=params.transformer_lm_encoder_dim,
embedding_dim=params.transformer_lm_embedding_dim,
dim_feedforward=params.transformer_lm_dim_feedforward,