mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-06 15:44:17 +00:00
update decoding from checkpoint
This commit is contained in:
parent
0d6d8f9473
commit
e81545714a
@ -67,7 +67,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--epoch",
|
"--epoch",
|
||||||
type=int,
|
type=int,
|
||||||
default=49,
|
default=-1,
|
||||||
help="It specifies the checkpoint to use for decoding."
|
help="It specifies the checkpoint to use for decoding."
|
||||||
"Note: Epoch counts from 0.",
|
"Note: Epoch counts from 0.",
|
||||||
)
|
)
|
||||||
@ -336,8 +336,13 @@ def main():
|
|||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
|
|
||||||
model_name_or_card = "seamlessM4T_medium"
|
model_name_or_card = "seamlessM4T_medium"
|
||||||
model_name_or_card = "seamlessM4T_large"
|
#model_name_or_card = "seamlessM4T_large"
|
||||||
model = load_unity_model(model_name_or_card, device=device, dtype=dtype)
|
model = load_unity_model(model_name_or_card, device=device, dtype=dtype)
|
||||||
|
del model.t2u_model
|
||||||
|
del model.text_encoder
|
||||||
|
del model.text_encoder_frontend
|
||||||
|
if params.epoch > 0:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user