From d926585b10e48a6707b096b532672323ba522638 Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Fri, 8 Sep 2023 16:46:42 +0800 Subject: [PATCH] fix loading --- egs/aishell/ASR/seamlessm4t/decode2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/aishell/ASR/seamlessm4t/decode2.py b/egs/aishell/ASR/seamlessm4t/decode2.py index be70d041f..dc5e5ea95 100644 --- a/egs/aishell/ASR/seamlessm4t/decode2.py +++ b/egs/aishell/ASR/seamlessm4t/decode2.py @@ -58,7 +58,7 @@ from fairseq2.generation import ( SequenceToTextGenerator, ) from seamless_communication.models.unity.model import UnitYX2TModel - +from fairseq2.nn.embedding import Embedding def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -347,6 +347,8 @@ def main(): del model.t2u_model del model.text_encoder del model.text_encoder_frontend + model.text_decoder_frontend.embed = Embedding(num_embeddings=params.tokenzier.vocab_size, embedding_dim=1024 ,pad_idx=0, scaled=True) + model.final_proj = nn.Linear(1024, params.tokenizer.vocab_size) if params.epoch > 0: if params.avg > 1: start = params.epoch - params.avg