update decoding from checkpoint

This commit is contained in:
Yuekai Zhang 2023-09-07 17:39:37 +08:00
parent 0d6d8f9473
commit e81545714a

View File

@ -67,7 +67,7 @@ def get_parser():
parser.add_argument(
"--epoch",
type=int,
default=49,
default=-1,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
@ -336,8 +336,13 @@ def main():
dtype = torch.float16
model_name_or_card = "seamlessM4T_medium"
model_name_or_card = "seamlessM4T_large"
#model_name_or_card = "seamlessM4T_large"
model = load_unity_model(model_name_or_card, device=device, dtype=dtype)
del model.t2u_model
del model.text_encoder
del model.text_encoder_frontend
if params.epoch > 0:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])