minor fix for param. names

This commit is contained in:
jinzr 2024-02-08 09:35:48 +08:00
parent a813186f64
commit eedc6b2cec

View File

@ -159,7 +159,7 @@ class LmScorer(torch.nn.Module):
""" """
if lm_type == "rnn": if lm_type == "rnn":
model = RnnLmModel( model = RnnLmModel(
vocab_size=params.vocab_size, vocab_size=params.lm_vocab_size,
embedding_dim=params.rnn_lm_embedding_dim, embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim, hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers, num_layers=params.rnn_lm_num_layers,
@ -183,7 +183,7 @@ class LmScorer(torch.nn.Module):
elif lm_type == "transformer": elif lm_type == "transformer":
model = TransformerLM( model = TransformerLM(
vocab_size=params.vocab_size, vocab_size=params.lm_vocab_size,
d_model=params.transformer_lm_encoder_dim, d_model=params.transformer_lm_encoder_dim,
embedding_dim=params.transformer_lm_embedding_dim, embedding_dim=params.transformer_lm_embedding_dim,
dim_feedforward=params.transformer_lm_dim_feedforward, dim_feedforward=params.transformer_lm_dim_feedforward,