mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 06:34:20 +00:00
update decoding from checkpoint
This commit is contained in:
parent
0d6d8f9473
commit
e81545714a
@ -67,7 +67,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=49,
|
||||
default=-1,
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
@ -336,8 +336,13 @@ def main():
|
||||
dtype = torch.float16
|
||||
|
||||
model_name_or_card = "seamlessM4T_medium"
|
||||
model_name_or_card = "seamlessM4T_large"
|
||||
#model_name_or_card = "seamlessM4T_large"
|
||||
model = load_unity_model(model_name_or_card, device=device, dtype=dtype)
|
||||
del model.t2u_model
|
||||
del model.text_encoder
|
||||
del model.text_encoder_frontend
|
||||
if params.epoch > 0:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
|
Loading…
x
Reference in New Issue
Block a user