From d766dc5aeea1a8aefab033e581948b07c4ac4bc0 Mon Sep 17 00:00:00 2001 From: whsqkaak Date: Fri, 22 Apr 2022 16:54:59 +0900 Subject: [PATCH 1/3] Fix some typos. (#329) --- egs/librispeech/ASR/transducer/decoder.py | 6 +++--- egs/librispeech/ASR/transducer_lstm/decoder.py | 6 +++--- egs/yesno/ASR/transducer/decoder.py | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/transducer/decoder.py b/egs/librispeech/ASR/transducer/decoder.py index 7b529ac19..333fff300 100644 --- a/egs/librispeech/ASR/transducer/decoder.py +++ b/egs/librispeech/ASR/transducer/decoder.py @@ -89,9 +89,9 @@ class Decoder(nn.Module): - (h, c), containing the state information for LSTM layers. Both are of shape (num_layers, N, C) """ - embeding_out = self.embedding(y) - embeding_out = self.embedding_dropout(embeding_out) - rnn_out, (h, c) = self.rnn(embeding_out, states) + embedding_out = self.embedding(y) + embedding_out = self.embedding_dropout(embedding_out) + rnn_out, (h, c) = self.rnn(embedding_out, states) out = self.output_linear(rnn_out) return out, (h, c) diff --git a/egs/librispeech/ASR/transducer_lstm/decoder.py b/egs/librispeech/ASR/transducer_lstm/decoder.py index 2f6bf4c07..4d531bde1 100644 --- a/egs/librispeech/ASR/transducer_lstm/decoder.py +++ b/egs/librispeech/ASR/transducer_lstm/decoder.py @@ -93,9 +93,9 @@ class Decoder(nn.Module): - (h, c), containing the state information for LSTM layers. Both are of shape (num_layers, N, C) """ - embeding_out = self.embedding(y) - embeding_out = self.embedding_dropout(embeding_out) - rnn_out, (h, c) = self.rnn(embeding_out, states) + embedding_out = self.embedding(y) + embedding_out = self.embedding_dropout(embedding_out) + rnn_out, (h, c) = self.rnn(embedding_out, states) out = self.output_linear(rnn_out) return out, (h, c) diff --git a/egs/yesno/ASR/transducer/decoder.py b/egs/yesno/ASR/transducer/decoder.py index aa8a16845..7ae540d03 100644 --- a/egs/yesno/ASR/transducer/decoder.py +++ b/egs/yesno/ASR/transducer/decoder.py @@ -84,9 +84,9 @@ class Decoder(nn.Module): - (h, c), which contain the state information for RNN layers. Both are of shape (num_layers, N, C) """ - embeding_out = self.embedding(y) - embeding_out = self.embedding_dropout(embeding_out) - rnn_out, (h, c) = self.rnn(embeding_out, states) + embedding_out = self.embedding(y) + embedding_out = self.embedding_dropout(embedding_out) + rnn_out, (h, c) = self.rnn(embedding_out, states) out = self.output_linear(rnn_out) return out, (h, c) From 9a98e6ced6370e42f69a8d904ab66a481cfb4d6f Mon Sep 17 00:00:00 2001 From: pehonnet Date: Mon, 25 Apr 2022 12:51:53 +0200 Subject: [PATCH 2/3] fix fp16 option in example usage (#332) --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 80617847a..d15c44388 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -35,7 +35,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --world-size 4 \ --num-epochs 30 \ --start-epoch 0 \ - --use_fp16 1 \ + --use-fp16 1 \ --exp-dir pruned_transducer_stateless2/exp \ --full-libri 1 \ --max-duration 550 From 9aeea3e1af2b288d0cd186a64a4ec2e37eccedc8 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 26 Apr 2022 13:32:03 +0800 Subject: [PATCH 3/3] Support averaging models with weight tying. (#333) --- icefall/checkpoint.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index cc167292b..a4e71a148 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -150,12 +150,25 @@ def average_checkpoints( n = len(filenames) avg = torch.load(filenames[0], map_location=device)["model"] + + # Identify shared parameters. Two parameters are said to be shared + # if they have the same data_ptr + uniqued: Dict[int, str] = dict() + + for k, v in avg.items(): + v_data_ptr = v.data_ptr() + if v_data_ptr in uniqued: + continue + uniqued[v_data_ptr] = k + + uniqued_names = list(uniqued.values()) + for i in range(1, n): state_dict = torch.load(filenames[i], map_location=device)["model"] - for k in avg: + for k in uniqued_names: avg[k] += state_dict[k] - for k in avg: + for k in uniqued_names: if avg[k].is_floating_point(): avg[k] /= n else: