mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 06:34:20 +00:00
Merge branch 'model-averaging-shared-params' of https://github.com/csukuangfj/icefall into knowledge_base_1b_merge
This commit is contained in:
commit
551786b9bd
@ -35,7 +35,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 0 \
|
||||
--use_fp16 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir pruned_transducer_stateless2/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 550
|
||||
|
@ -150,12 +150,25 @@ def average_checkpoints(
|
||||
n = len(filenames)
|
||||
|
||||
avg = torch.load(filenames[0], map_location=device)["model"]
|
||||
|
||||
# Identify shared parameters. Two parameters are said to be shared
|
||||
# if they have the same data_ptr
|
||||
uniqued: Dict[int, str] = dict()
|
||||
|
||||
for k, v in avg.items():
|
||||
v_data_ptr = v.data_ptr()
|
||||
if v_data_ptr in uniqued:
|
||||
continue
|
||||
uniqued[v_data_ptr] = k
|
||||
|
||||
uniqued_names = list(uniqued.values())
|
||||
|
||||
for i in range(1, n):
|
||||
state_dict = torch.load(filenames[i], map_location=device)["model"]
|
||||
for k in avg:
|
||||
for k in uniqued_names:
|
||||
avg[k] += state_dict[k]
|
||||
|
||||
for k in avg:
|
||||
for k in uniqued_names:
|
||||
if avg[k].is_floating_point():
|
||||
avg[k] /= n
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user