From e81545714af802ff789208b058e96e3078a50aa4 Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Thu, 7 Sep 2023 17:39:37 +0800 Subject: [PATCH] update decoding from checkpoint --- egs/aishell/ASR/seamlessm4t/decode.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/egs/aishell/ASR/seamlessm4t/decode.py b/egs/aishell/ASR/seamlessm4t/decode.py index 9ed11eb55..255abba00 100755 --- a/egs/aishell/ASR/seamlessm4t/decode.py +++ b/egs/aishell/ASR/seamlessm4t/decode.py @@ -67,7 +67,7 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=49, + default=-1, help="It specifies the checkpoint to use for decoding." "Note: Epoch counts from 0.", ) @@ -336,8 +336,13 @@ def main(): dtype = torch.float16 model_name_or_card = "seamlessM4T_medium" - model_name_or_card = "seamlessM4T_large" + #model_name_or_card = "seamlessM4T_large" model = load_unity_model(model_name_or_card, device=device, dtype=dtype) + del model.t2u_model + del model.text_encoder + del model.text_encoder_frontend + if params.epoch > 0: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) model.to(device) model.eval() num_param = sum([p.numel() for p in model.parameters()])