diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index 80617847a..d15c44388 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -35,7 +35,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 0 \
-  --use_fp16 1 \
+  --use-fp16 1 \
   --exp-dir pruned_transducer_stateless2/exp \
   --full-libri 1 \
   --max-duration 550
diff --git a/egs/librispeech/ASR/transducer/decoder.py b/egs/librispeech/ASR/transducer/decoder.py
index 7b529ac19..333fff300 100644
--- a/egs/librispeech/ASR/transducer/decoder.py
+++ b/egs/librispeech/ASR/transducer/decoder.py
@@ -89,9 +89,9 @@ class Decoder(nn.Module):
           - (h, c), containing the state information for LSTM layers.
             Both are of shape (num_layers, N, C)
         """
-        embeding_out = self.embedding(y)
-        embeding_out = self.embedding_dropout(embeding_out)
-        rnn_out, (h, c) = self.rnn(embeding_out, states)
+        embedding_out = self.embedding(y)
+        embedding_out = self.embedding_dropout(embedding_out)
+        rnn_out, (h, c) = self.rnn(embedding_out, states)
         out = self.output_linear(rnn_out)
         return out, (h, c)
diff --git a/egs/yesno/ASR/transducer/decoder.py b/egs/yesno/ASR/transducer/decoder.py
index aa8a16845..7ae540d03 100644
--- a/egs/yesno/ASR/transducer/decoder.py
+++ b/egs/yesno/ASR/transducer/decoder.py
@@ -84,9 +84,9 @@ class Decoder(nn.Module):
           - (h, c), which contain the state information for RNN layers.
             Both are of shape (num_layers, N, C)
         """
-        embeding_out = self.embedding(y)
-        embeding_out = self.embedding_dropout(embeding_out)
-        rnn_out, (h, c) = self.rnn(embeding_out, states)
+        embedding_out = self.embedding(y)
+        embedding_out = self.embedding_dropout(embedding_out)
+        rnn_out, (h, c) = self.rnn(embedding_out, states)
         out = self.output_linear(rnn_out)
         return out, (h, c)
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index cc167292b..a4e71a148 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -150,12 +150,25 @@ def average_checkpoints(
     n = len(filenames)
 
     avg = torch.load(filenames[0], map_location=device)["model"]
+
+    # Identify shared parameters. Two parameters are said to be shared
+    # if they have the same data_ptr
+    uniqued: Dict[int, str] = dict()
+
+    for k, v in avg.items():
+        v_data_ptr = v.data_ptr()
+        if v_data_ptr in uniqued:
+            continue
+        uniqued[v_data_ptr] = k
+
+    uniqued_names = list(uniqued.values())
+
     for i in range(1, n):
         state_dict = torch.load(filenames[i], map_location=device)["model"]
-        for k in avg:
+        for k in uniqued_names:
             avg[k] += state_dict[k]
 
-    for k in avg:
+    for k in uniqued_names:
         if avg[k].is_floating_point():
             avg[k] /= n
         else:
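
Note on the icefall/checkpoint.py hunk: when a checkpoint contains tied parameters (two state-dict keys aliasing the same storage, as torch.load preserves), iterating over every key accumulates the shared tensor once per alias, so it ends up over-counted before the division by n. The patch visits each unique storage exactly once by keying on data_ptr(). Below is a minimal, runnable sketch of that technique, assuming only PyTorch; the average_state_dicts helper and the toy tied-weight checkpoints are hypothetical illustrations, not icefall code.

# Sketch of the data_ptr()-based deduplication from the patch above.
# Hypothetical example; not part of icefall.
from typing import Dict, List

import torch


def average_state_dicts(
    dicts: List[Dict[str, torch.Tensor]]
) -> Dict[str, torch.Tensor]:
    n = len(dicts)
    # Averaged in place, like the state dict loaded from filenames[0].
    avg = dicts[0]

    # Keep one key per unique storage; aliased keys are skipped so that
    # a shared tensor is accumulated and divided exactly once.
    uniqued: Dict[int, str] = {}
    for k, v in avg.items():
        uniqued.setdefault(v.data_ptr(), k)
    uniqued_names = list(uniqued.values())

    for i in range(1, n):
        for k in uniqued_names:
            avg[k] += dicts[i][k]

    for k in uniqued_names:
        if avg[k].is_floating_point():
            avg[k] /= n
        else:
            avg[k] //= n
    return avg


def make_ckpt(value: float) -> Dict[str, torch.Tensor]:
    w = torch.full((2,), value)
    # Both keys alias the same tensor, as tied weights do in a saved
    # checkpoint. The key names here are made up for the example.
    return {"decoder.weight": w, "joiner.weight": w}


avg = average_state_dicts([make_ckpt(1.0), make_ckpt(3.0)])
print(avg["decoder.weight"])  # tensor([2., 2.]) -- the correct average

Without the deduplication, the shared tensor would receive += 3.0 twice (once per alias) and then be divided by n twice, yielding tensor([1.75, 1.75]) instead of the correct tensor([2., 2.]).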