Merge branch 'model-averaging-shared-params' of https://github.com/csukuangfj/icefall into knowledge_base_1b_merge

This commit is contained in:
Daniel Povey 2022-04-26 13:18:09 +08:00
commit 551786b9bd
2 changed files with 16 additions and 3 deletions

View File

@ -35,7 +35,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--use_fp16 1 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless2/exp \
--full-libri 1 \
--max-duration 550

View File

@ -150,12 +150,25 @@ def average_checkpoints(
n = len(filenames)
avg = torch.load(filenames[0], map_location=device)["model"]
# Identify shared parameters. Two parameters are said to be shared
# if they have the same data_ptr
uniqued: Dict[int, str] = dict()
for k, v in avg.items():
v_data_ptr = v.data_ptr()
if v_data_ptr in uniqued:
continue
uniqued[v_data_ptr] = k
uniqued_names = list(uniqued.values())
for i in range(1, n):
state_dict = torch.load(filenames[i], map_location=device)["model"]
for k in avg:
for k in uniqued_names:
avg[k] += state_dict[k]
for k in avg:
for k in uniqued_names:
if avg[k].is_floating_point():
avg[k] /= n
else: