update decoding from checkpoint

This commit is contained in:
Yuekai Zhang 2023-09-07 17:39:37 +08:00
parent 0d6d8f9473
commit e81545714a

View File

@ -67,7 +67,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--epoch", "--epoch",
type=int, type=int,
default=49, default=-1,
help="It specifies the checkpoint to use for decoding." help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.", "Note: Epoch counts from 0.",
) )
@ -336,8 +336,13 @@ def main():
dtype = torch.float16 dtype = torch.float16
model_name_or_card = "seamlessM4T_medium" model_name_or_card = "seamlessM4T_medium"
model_name_or_card = "seamlessM4T_large" #model_name_or_card = "seamlessM4T_large"
model = load_unity_model(model_name_or_card, device=device, dtype=dtype) model = load_unity_model(model_name_or_card, device=device, dtype=dtype)
del model.t2u_model
del model.text_encoder
del model.text_encoder_frontend
if params.epoch > 0:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.to(device) model.to(device)
model.eval() model.eval()
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])