mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Support averaging models with weight tying. (#333)
This commit is contained in:
parent
9a98e6ced6
commit
9aeea3e1af
@ -150,12 +150,25 @@ def average_checkpoints(
|
|||||||
n = len(filenames)
|
n = len(filenames)
|
||||||
|
|
||||||
avg = torch.load(filenames[0], map_location=device)["model"]
|
avg = torch.load(filenames[0], map_location=device)["model"]
|
||||||
|
|
||||||
|
# Identify shared parameters. Two parameters are said to be shared
|
||||||
|
# if they have the same data_ptr
|
||||||
|
uniqued: Dict[int, str] = dict()
|
||||||
|
|
||||||
|
for k, v in avg.items():
|
||||||
|
v_data_ptr = v.data_ptr()
|
||||||
|
if v_data_ptr in uniqued:
|
||||||
|
continue
|
||||||
|
uniqued[v_data_ptr] = k
|
||||||
|
|
||||||
|
uniqued_names = list(uniqued.values())
|
||||||
|
|
||||||
for i in range(1, n):
|
for i in range(1, n):
|
||||||
state_dict = torch.load(filenames[i], map_location=device)["model"]
|
state_dict = torch.load(filenames[i], map_location=device)["model"]
|
||||||
for k in avg:
|
for k in uniqued_names:
|
||||||
avg[k] += state_dict[k]
|
avg[k] += state_dict[k]
|
||||||
|
|
||||||
for k in avg:
|
for k in uniqued_names:
|
||||||
if avg[k].is_floating_point():
|
if avg[k].is_floating_point():
|
||||||
avg[k] /= n
|
avg[k] /= n
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user