From 9a98e6ced6370e42f69a8d904ab66a481cfb4d6f Mon Sep 17 00:00:00 2001 From: pehonnet Date: Mon, 25 Apr 2022 12:51:53 +0200 Subject: [PATCH 1/2] fix fp16 option in example usage (#332) --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 80617847a..d15c44388 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -35,7 +35,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --world-size 4 \ --num-epochs 30 \ --start-epoch 0 \ - --use_fp16 1 \ + --use-fp16 1 \ --exp-dir pruned_transducer_stateless2/exp \ --full-libri 1 \ --max-duration 550 From 0da522cc4cf9e9b9f0460a6c5e1953790a320c57 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 26 Apr 2022 13:14:08 +0800 Subject: [PATCH 2/2] Support averaging models with weight tying. --- icefall/checkpoint.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index cc167292b..a4e71a148 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -150,12 +150,25 @@ def average_checkpoints( n = len(filenames) avg = torch.load(filenames[0], map_location=device)["model"] + + # Identify shared parameters. Two parameters are said to be shared + # if they have the same data_ptr + uniqued: Dict[int, str] = dict() + + for k, v in avg.items(): + v_data_ptr = v.data_ptr() + if v_data_ptr in uniqued: + continue + uniqued[v_data_ptr] = k + + uniqued_names = list(uniqued.values()) + for i in range(1, n): state_dict = torch.load(filenames[i], map_location=device)["model"] - for k in avg: + for k in uniqued_names: avg[k] += state_dict[k] - for k in avg: + for k in uniqued_names: if avg[k].is_floating_point(): avg[k] /= n else: