diff --git a/egs/aishell/ASR/seamlessm4t/decode.py b/egs/aishell/ASR/seamlessm4t/decode.py index 9ed11eb55..255abba00 100755 --- a/egs/aishell/ASR/seamlessm4t/decode.py +++ b/egs/aishell/ASR/seamlessm4t/decode.py @@ -67,7 +67,7 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=49, + default=-1, help="It specifies the checkpoint to use for decoding." "Note: Epoch counts from 0.", ) @@ -336,8 +336,13 @@ def main(): dtype = torch.float16 model_name_or_card = "seamlessM4T_medium" - model_name_or_card = "seamlessM4T_large" + #model_name_or_card = "seamlessM4T_large" model = load_unity_model(model_name_or_card, device=device, dtype=dtype) + del model.t2u_model + del model.text_encoder + del model.text_encoder_frontend + if params.epoch > 0: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) model.to(device) model.eval() num_param = sum([p.numel() for p in model.parameters()])