Support averaging models with weight tying. (#333)

This commit is contained in:
Fangjun Kuang 2022-04-26 13:32:03 +08:00 committed by GitHub
parent 9a98e6ced6
commit 9aeea3e1af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -150,12 +150,25 @@ def average_checkpoints(
n = len(filenames)
avg = torch.load(filenames[0], map_location=device)["model"]
# Identify shared parameters. Two parameters are said to be shared
# if they have the same data_ptr(), i.e., they point to the same memory.
uniqued: Dict[int, str] = dict()
for k, v in avg.items():
v_data_ptr = v.data_ptr()
if v_data_ptr in uniqued:
continue
uniqued[v_data_ptr] = k
uniqued_names = list(uniqued.values())
for i in range(1, n):
state_dict = torch.load(filenames[i], map_location=device)["model"]
for k in avg:
for k in uniqued_names:
avg[k] += state_dict[k]
for k in avg:
for k in uniqued_names:
if avg[k].is_floating_point():
avg[k] /= n
else: