fix loading

This commit is contained in:
Yuekai Zhang 2023-09-08 16:46:42 +08:00
parent 2a288fb9bf
commit d926585b10

View File

@ -58,7 +58,7 @@ from fairseq2.generation import (
SequenceToTextGenerator,
)
from seamless_communication.models.unity.model import UnitYX2TModel
from fairseq2.nn.embedding import Embedding
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -347,6 +347,8 @@ def main():
del model.t2u_model
del model.text_encoder
del model.text_encoder_frontend
model.text_decoder_frontend.embed = Embedding(num_embeddings=params.tokenzier.vocab_size, embedding_dim=1024 ,pad_idx=0, scaled=True)
model.final_proj = nn.Linear(1024, params.tokenizer.vocab_size)
if params.epoch > 0:
if params.avg > 1:
start = params.epoch - params.avg