mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Support averaging models with weight tying. (#333)
This commit is contained in:
parent
9a98e6ced6
commit
9aeea3e1af
@ -150,12 +150,25 @@ def average_checkpoints(
|
||||
n = len(filenames)
|
||||
|
||||
avg = torch.load(filenames[0], map_location=device)["model"]
|
||||
|
||||
# Identify shared parameters. Two parameters are said to be shared
|
||||
# if they have the same data_ptr(), i.e. they start at the same memory address.
|
||||
uniqued: Dict[int, str] = dict()
|
||||
|
||||
for k, v in avg.items():
|
||||
v_data_ptr = v.data_ptr()
|
||||
if v_data_ptr in uniqued:
|
||||
continue
|
||||
uniqued[v_data_ptr] = k
|
||||
|
||||
uniqued_names = list(uniqued.values())
|
||||
|
||||
for i in range(1, n):
|
||||
state_dict = torch.load(filenames[i], map_location=device)["model"]
|
||||
for k in avg:
|
||||
for k in uniqued_names:
|
||||
avg[k] += state_dict[k]
|
||||
|
||||
for k in avg:
|
||||
for k in uniqued_names:
|
||||
if avg[k].is_floating_point():
|
||||
avg[k] /= n
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user