diff --git a/egs/aishell/ASR/seamlessm4t/decode2.py b/egs/aishell/ASR/seamlessm4t/decode2.py index be70d041f..dc5e5ea95 100644 --- a/egs/aishell/ASR/seamlessm4t/decode2.py +++ b/egs/aishell/ASR/seamlessm4t/decode2.py @@ -58,7 +58,7 @@ from fairseq2.generation import ( SequenceToTextGenerator, ) from seamless_communication.models.unity.model import UnitYX2TModel - +from fairseq2.nn.embedding import Embedding def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -347,6 +347,8 @@ def main(): del model.t2u_model del model.text_encoder del model.text_encoder_frontend + model.text_decoder_frontend.embed = Embedding(num_embeddings=params.tokenzier.vocab_size, embedding_dim=1024 ,pad_idx=0, scaled=True) + model.final_proj = nn.Linear(1024, params.tokenizer.vocab_size) if params.epoch > 0: if params.avg > 1: start = params.epoch - params.avg