from local

This commit is contained in:
dohe0342 2022-12-20 16:16:46 +09:00
parent 225e8b7beb
commit 6c82d84694
2 changed files with 77 additions and 74 deletions

View File

@@ -660,82 +660,85 @@ def main():
logging.info("About to create model") logging.info("About to create model")
model = get_transducer_model(params) model = get_transducer_model(params)
if not params.use_averaged_model: if params.model_path:
if params.iter > 0: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else: else:
if params.iter > 0: if not params.use_averaged_model:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ if params.iter > 0:
: params.avg + 1 filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
] : params.avg
if len(filenames) == 0: ]
raise ValueError( if len(filenames) == 0:
f"No checkpoints found for" raise ValueError(
f" --iter {params.iter}, --avg {params.avg}" f"No checkpoints found for"
) f" --iter {params.iter}, --avg {params.avg}"
elif len(filenames) < params.avg + 1: )
raise ValueError( elif len(filenames) < params.avg:
f"Not enough checkpoints ({len(filenames)}) found for" raise ValueError(
f" --iter {params.iter}, --avg {params.avg}" f"Not enough checkpoints ({len(filenames)}) found for"
) f" --iter {params.iter}, --avg {params.avg}"
filename_start = filenames[-1] )
filename_end = filenames[0] logging.info(f"averaging {filenames}")
logging.info( model.to(device)
"Calculating the averaged model over iteration checkpoints" model.load_state_dict(average_checkpoints(filenames, device=device))
f" from {filename_start} (excluded) to {filename_end}" elif params.avg == 1:
) load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.to(device) else:
model.load_state_dict( start = params.epoch - params.avg + 1
average_checkpoints_with_averaged_model( filenames = []
filename_start=filename_start, for i in range(start, params.epoch + 1):
filename_end=filename_end, if i >= 1:
device=device, filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
) logging.info(f"averaging {filenames}")
) model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else: else:
assert params.avg > 0, params.avg if params.iter > 0:
start = params.epoch - params.avg filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
assert start >= 1, start : params.avg + 1
filename_start = f"{params.exp_dir}/epoch-{start}.pt" ]
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" if len(filenames) == 0:
logging.info( raise ValueError(
f"Calculating the averaged model over epoch range from " f"No checkpoints found for"
f"{start} (excluded) to {params.epoch}" f" --iter {params.iter}, --avg {params.avg}"
) )
model.to(device) elif len(filenames) < params.avg + 1:
model.load_state_dict( raise ValueError(
average_checkpoints_with_averaged_model( f"Not enough checkpoints ({len(filenames)}) found for"
filename_start=filename_start, f" --iter {params.iter}, --avg {params.avg}"
filename_end=filename_end, )
device=device, filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
) )
)
model.to(device) model.to(device)
model.eval() model.eval()