minor fix for param. names

This commit is contained in:
jinzr 2024-02-08 09:35:48 +08:00
parent a813186f64
commit eedc6b2cec

View File

@ -159,7 +159,7 @@ class LmScorer(torch.nn.Module):
"""
if lm_type == "rnn":
model = RnnLmModel(
vocab_size=params.vocab_size,
vocab_size=params.lm_vocab_size,
embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers,
@ -183,7 +183,7 @@ class LmScorer(torch.nn.Module):
elif lm_type == "transformer":
model = TransformerLM(
vocab_size=params.vocab_size,
vocab_size=params.lm_vocab_size,
d_model=params.transformer_lm_encoder_dim,
embedding_dim=params.transformer_lm_embedding_dim,
dim_feedforward=params.transformer_lm_dim_feedforward,