minor fix of train.py

This commit is contained in:
yaozengwei 2022-05-09 16:41:46 +08:00
parent d0cea4f2f8
commit 6c5fd6f648

View File

@ -394,6 +394,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
blank_id=params.blank_id,
unk_id=params.unk_id,
context_size=params.context_size,
)
return decoder
@ -829,6 +830,7 @@ def run(rank, world_size, args):
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)