From 9aeea3e1af2b288d0cd186a64a4ec2e37eccedc8 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 26 Apr 2022 13:32:03 +0800 Subject: [PATCH] Support averaging models with weight tying. (#333) --- icefall/checkpoint.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index cc167292b..a4e71a148 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -150,12 +150,25 @@ def average_checkpoints( n = len(filenames) avg = torch.load(filenames[0], map_location=device)["model"] + + # Identify shared parameters. Two parameters are said to be shared + # if they have the same data_ptr + uniqued: Dict[int, str] = dict() + + for k, v in avg.items(): + v_data_ptr = v.data_ptr() + if v_data_ptr in uniqued: + continue + uniqued[v_data_ptr] = k + + uniqued_names = list(uniqued.values()) + for i in range(1, n): state_dict = torch.load(filenames[i], map_location=device)["model"] - for k in avg: + for k in uniqued_names: avg[k] += state_dict[k] - for k in avg: + for k in uniqued_names: if avg[k].is_floating_point(): avg[k] /= n else: